jerrrrry / mlperf_transformer_v0.7 · Commits

Commit 9e8a8c05
Initial commit
Authored Oct 14, 2024 by jerrrrry

Changes: 209 files in this commit; showing 20 changed files with 2353 additions and 0 deletions (+2353, -0).
implementations/pytorch/fairseq/data/data_utils.py                                               +535  -0
implementations/pytorch/fairseq/data/dictionary.py                                               +412  -0
implementations/pytorch/fairseq/data/fairseq_dataset.py                                          +38   -0
implementations/pytorch/fairseq/data/indexed_dataset.py                                          +277  -0
implementations/pytorch/fairseq/data/language_pair_dataset.py                                    +234  -0
implementations/pytorch/fairseq/data/monolingual_dataset.py                                      +81   -0
implementations/pytorch/fairseq/data/token_block_dataset.py                                      +93   -0
implementations/pytorch/fairseq/distributed_utils.py                                             +92   -0
implementations/pytorch/fairseq/fp16_trainer.py                                                  +455  -0
implementations/pytorch/fairseq/meters.py                                                        +73   -0
implementations/pytorch/fairseq/models/__init__.py                                               +63   -0
implementations/pytorch/fairseq/models/__pycache__/__init__.cpython-310.pyc                      +0    -0
implementations/pytorch/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc             +0    -0
implementations/pytorch/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc               +0    -0
implementations/pytorch/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc               +0    -0
implementations/pytorch/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc   +0    -0
implementations/pytorch/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc                 +0    -0
implementations/pytorch/fairseq/models/__pycache__/fconv.cpython-310.pyc                         +0    -0
implementations/pytorch/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc                +0    -0
implementations/pytorch/fairseq/models/__pycache__/lstm.cpython-310.pyc                          +0    -0
implementations/pytorch/fairseq/data/data_utils.py  (new file, 0 → 100644)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import contextlib
import itertools
import math
import os
import statistics
import time

import numpy as np
import torch

from . import FairseqDataset
import fairseq.data.batch_C_v0p5
import fairseq.data.batch_C_v0p5_better
import fairseq.data.batch_C_v0p6
import sys


def infer_language_pair(path):
    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
    src, dst = None, None
    print('Infer language pair from filename...')
    for filename in os.listdir(path):
        print('filename:', filename)
        parts = filename.split('.')
        if len(parts) >= 3 and len(parts[1].split('-')) == 2:
            return parts[1].split('-')
    return src, dst


class ShardedIterator(object):
    """A sharded wrapper around an iterable (padded to length)."""

    def __init__(self, iterable, num_shards, shard_id, fill_value=None):
        if shard_id < 0 or shard_id >= num_shards:
            raise ValueError('shard_id must be between 0 and num_shards')

        self._sharded_len = len(iterable) // num_shards
        if len(iterable) % num_shards > 0:
            self._sharded_len += 1

        self.itr = itertools.zip_longest(
            range(self._sharded_len),
            itertools.islice(iterable, shard_id, len(iterable), num_shards),
            fillvalue=fill_value,
        )

    def __len__(self):
        return self._sharded_len

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.itr)[1]


class CountingIterator(object):
    """Wrapper around an iterable that maintains the iteration count."""

    def __init__(self, iterable):
        self.iterable = iterable
        self.count = 0
        self.itr = iter(self)

    def __len__(self):
        return len(self.iterable)

    def __iter__(self):
        for x in self.iterable:
            self.count += 1
            yield x

    def __next__(self):
        return next(self.itr)

    def has_next(self):
        return self.count < len(self)

    def skip(self, num_to_skip):
        next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
        return self


def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False,
                   n_seq_per_batch_multiple=8, seq_len_multiple=1):
    """ Convert a list of 1d tensors into a padded 2d tensor.

    Args:
        values: Python list where each element is a PyT 1d tensor
        pad_idx: The index into the translation dictionary for the pad token (typically refer to 'dict.pad()')
        eos_idx: The index into the translation dictionary for the eos token (typically refer to 'dict.eos()')
        left_pad: Bool, left- or right-padding (true: left, false: right)
        move_eos_to_beginning: Reverse order of sequence of tokens (true: reverse, false: leave in original order)
        n_seq_per_batch_multiple: The number of sequences per batch to round down to
        seq_len_multiple: The number of tokens per sequence to round up to
    """
    size_of_seq_dim = max(v.size(0) for v in values)  # Unpadded size
    n_seq_in_batch = len(values)

    if n_seq_per_batch_multiple % seq_len_multiple == 0:
        n_seq_multiple = n_seq_per_batch_multiple / seq_len_multiple
    else:
        n_seq_multiple = n_seq_per_batch_multiple

    if n_seq_in_batch < n_seq_multiple or n_seq_in_batch % n_seq_multiple > 0:
        seq_len_multiple = n_seq_per_batch_multiple

    # Padded seq len, rounded up to next multiple
    size_of_seq_dim = (size_of_seq_dim + seq_len_multiple - 1) // seq_len_multiple * seq_len_multiple

    padded_2d_tensor = values[0].new(len(values), size_of_seq_dim).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if move_eos_to_beginning:
            assert src[-1] == eos_idx
            dst[0] = eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    if left_pad:
        for idx, val in enumerate(values):
            copy_tensor(val, padded_2d_tensor[idx][size_of_seq_dim - len(val):])
    else:
        for idx, val in enumerate(values):
            copy_tensor(val, padded_2d_tensor[idx][:len(val)])

    return padded_2d_tensor


class EpochBatchIterator(object):
    """Iterate over a FairseqDataset and yield batches bucketed by size.

    Batches may contain sequences of different lengths. This iterator can be
    reused across multiple epochs with the next_epoch_itr() method.

    Args:
        dataset: a FairseqDataset
        max_tokens: max number of tokens in each batch
        max_sentences: max number of sentences in each batch
        max_positions: max sentence length supported by the model
        ignore_invalid_inputs: don't raise Exception for sentences that are too long
        required_batch_size_multiple: require batch size to be a multiple of N
        seeds: seeds for random number generator for reproducibility (1 seed for
            each training epoch)
        num_shards: shard the data iterator into N shards
        shard_id: which shard of the data iterator to return
    """

    def __init__(
        self, dataset, dataloader_num_workers=1, dataloader_pin_memory=False,
        max_tokens=None, max_sentences=None, max_positions=None,
        ignore_invalid_inputs=False, required_batch_size_multiple=1, seeds=[1],
        num_shards=1, shard_id=0, epoch=0, bucket_growth_factor=1.1,
        seq_len_multiple=1, batching_scheme='v0p5',
        batch_multiple_strategy='multiple_of_sequences',
    ):
        assert isinstance(dataset, FairseqDataset)
        self.dataset = dataset
        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
        self.dataloader_num_workers = dataloader_num_workers
        self.dataloader_pin_memory = dataloader_pin_memory
        assert len(max_positions) == 2, "Max positions contains source and target lengths!"
        max_src_pos, max_tgt_pos = max_positions
        self.max_positions = max_positions
        self.max_positions_num = min(max_src_pos, max_tgt_pos)
        self.ignore_invalid_inputs = ignore_invalid_inputs
        self.bsz_mult = required_batch_size_multiple
        self.seeds = seeds
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.seq_len_multiple = seq_len_multiple
        self.batching_scheme = batching_scheme
        self.batch_multiple_strategy = batch_multiple_strategy
        self.epoch = epoch
        self._cur_epoch_itr = None
        self._next_epoch_itr = None

        with numpy_seed(self.seeds[0]):
            import time
            start = time.time()
            indices = self.dataset.ordered_indices(self.seeds[self.epoch])

            # need integer, rather than float('Inf') values
            max_sentences = max_sentences if max_sentences is not None else sys.maxsize
            max_tokens = max_tokens if max_tokens is not None else sys.maxsize

            if self.batching_scheme == 'v0p5':
                batches = fairseq.data.batch_C_v0p5.make_batches_v0p5(
                    self.dataset.src_sizes, self.dataset.tgt_sizes, indices,
                    max_tokens, max_sentences, self.bsz_mult, self.max_positions_num)
            elif self.batching_scheme == 'v0p5_better':
                print('self.dataset.src_sizes', self.dataset.src_sizes.size)
                print('self.dataset.tgt_sizes', self.dataset.tgt_sizes.size)
                batches = fairseq.data.batch_C_v0p5_better.make_batches_v0p5_better(
                    self.dataset.src_sizes, self.dataset.tgt_sizes, indices,
                    max_tokens, max_sentences, self.max_positions_num,
                    self.bsz_mult, self.seq_len_multiple)
            elif self.batching_scheme == 'v0p6':
                batch_strategy = 2
                if self.batch_multiple_strategy == 'mult_of_sequences':
                    batch_strategy = 0
                elif self.batch_multiple_strategy == 'pad_sequence_to_mult':
                    batch_strategy = 1
                elif self.batch_multiple_strategy == 'dynamic':
                    batch_strategy = 2
                else:
                    assert False, "Unknown batch multiple strategy!"
                bucket_specify_min_boundary = 8
                use_efficient_last_pack = False
                #batch_strategy = 2
                batches = fairseq.data.batch_C_v0p6.make_batches_v0p6(
                    self.dataset.src_sizes, self.dataset.tgt_sizes, indices,
                    max_tokens, max_sentences, self.bsz_mult, self.max_positions_num,
                    bucket_specify_min_boundary, bucket_growth_factor,
                    batch_strategy, use_efficient_last_pack)
            else:
                # reference
                def roundup(x, multiple):
                    return (x + multiple - 1) // multiple * multiple

                def rounddown(x, multiple):
                    return x // multiple * multiple

                def create_bucket_bounds_lists(max_allowable_seq_length,
                                               bucket_specify_min_boundary,
                                               bucket_specify_growth_scale):
                    bucket_boundaries = []
                    x = bucket_specify_min_boundary
                    while x < max_allowable_seq_length:
                        bucket_boundaries.append(x)
                        x = max(x + 1, int(x * bucket_specify_growth_scale))
                    if use_efficient_last_pack:
                        buckets_min_list = [0] + [i + 1 for i in bucket_boundaries]
                        buckets_max_list = bucket_boundaries + [max_allowable_seq_length]
                    else:
                        buckets_min_list = [0] + bucket_boundaries
                        buckets_max_list = bucket_boundaries + [max_allowable_seq_length + 1]
                    return buckets_min_list, buckets_max_list

                def create_seq_to_bucket_id_list_and_n_seq_per_batch(
                        n_tok_per_seq, max_allowable_seq_length, max_sentences,
                        pad_seq_per_batch_to_multiple_of, pad_tok_per_seq_to_multiple_of,
                        bucket_specify_min_boundary, bucket_specify_growth_scale):
                    bucket_interval_min, bucket_interval_max = create_bucket_bounds_lists(
                        max_allowable_seq_length, bucket_specify_min_boundary,
                        bucket_specify_growth_scale)
                    if do_seq_len_padding_to_multiple:
                        n_seq_per_batch = [max_tokens // roundup(x, pad_tok_per_seq_to_multiple_of)
                                           for x in bucket_interval_max]
                    elif do_batch_size_rounding_down_to_multiple:
                        n_seq_per_batch = [rounddown(max_tokens // x, pad_seq_per_batch_to_multiple_of)
                                           for x in bucket_interval_max]
                    elif do_dynamic_batch_size_choice:
                        n_seq_per_batch_based_on_seq_len = [max_tokens // roundup(x, pad_tok_per_seq_to_multiple_of)
                                                            for x in bucket_interval_max]
                        n_seq_per_batch_based_on_n_seq = [rounddown(max_tokens // x, pad_seq_per_batch_to_multiple_of)
                                                          for x in bucket_interval_max]
                        n_seq_per_batch = [max(a, b) for a, b in zip(n_seq_per_batch_based_on_seq_len,
                                                                     n_seq_per_batch_based_on_n_seq)]
                    else:
                        n_seq_per_batch = [max_tokens // x for x in bucket_interval_max]
                    n_seq_per_batch = [min(max_sentences, i) if max_sentences is not None else i
                                       for i in n_seq_per_batch]
                    for a, b, c in zip(bucket_interval_min, bucket_interval_max, n_seq_per_batch):
                        print('bucket:', a, b, c)
                    token_length_2_bucket_id = {}
                    for x in range(max_allowable_seq_length + 1):
                        for bucket_id, payload in enumerate(zip(bucket_interval_min, bucket_interval_max)):
                            bmin, bmax = payload
                            if (bmin <= x and x <= bmax and use_efficient_last_pack) or (bmin <= x and x < bmax):
                                token_length_2_bucket_id[x] = bucket_id
                                break
                    return ([token_length_2_bucket_id[x] if x <= max_allowable_seq_length else -1
                             for x in n_tok_per_seq],
                            n_seq_per_batch,
                            len(bucket_interval_min))

                # Make adjustments to tuneable parameters here
                pad_seq_per_batch_to_multiple_of = self.bsz_mult
                pad_tok_per_seq_to_multiple_of = self.bsz_mult
                max_allowable_seq_length = self.max_positions_num
                bucket_specify_min_boundary = 8
                bucket_specify_growth_scale = bucket_growth_factor  ##1.035
                do_seq_len_padding_to_multiple = False
                do_batch_size_rounding_down_to_multiple = False
                do_dynamic_batch_size_choice = True
                use_efficient_last_pack = False

                batches = []
                src_token_counts = []
                dst_token_counts = []
                seq_counts = []
                padded_token_counts = []
                batch_max_padded_seq_len = 0
                batch_seq_count = 0
                batches.append([])
                src_batch_token_count = 0
                dst_batch_token_count = 0
                curr_batch_padded_token_count = 0
                batch_n_seq = 0
                bucket_id = 0
                longest_in_batch = []
                print('### max_tokens:', max_tokens)
                print('### max_sentences:', max_sentences)

                pairwise_max_seq_len = [max(a, b) for a, b in zip(dataset.src_sizes, dataset.tgt_sizes)]
                bucket_ids, n_seq_per_batch, n_buckets = create_seq_to_bucket_id_list_and_n_seq_per_batch(
                    pairwise_max_seq_len, max_allowable_seq_length, max_sentences,
                    pad_seq_per_batch_to_multiple_of, pad_tok_per_seq_to_multiple_of,
                    bucket_specify_min_boundary, bucket_specify_growth_scale)

                buckets = []
                for i in range(n_buckets):
                    buckets.append([])
                n_rejected_sequences = 0
                for idx, bidx in enumerate(bucket_ids):
                    if bidx >= 0:
                        buckets[bidx].append(idx)
                    else:
                        n_rejected_sequences += 1
                # Remove empty buckets (causes blow-up in eval code).
                buckets = [i for i in buckets if len(i) > 0]
                print(n_rejected_sequences, 'were omitted due to containing over 256 tokens.')

                batch_seq_count = 0
                #count = 0
                seq_len_tracker = 0
                for bucket, nspb in zip(buckets, n_seq_per_batch):
                    for item in bucket:
                        if batch_n_seq < nspb:
                            batches[-1].append(item)
                            src_batch_token_count += dataset.src_sizes[item]
                            dst_batch_token_count += dataset.tgt_sizes[item]
                            seq_len_tracker = max(seq_len_tracker, dst_batch_token_count)
                            batch_n_seq += 1
                        else:
                            batches.append([item])
                            src_token_counts.append(src_batch_token_count)
                            dst_token_counts.append(dst_batch_token_count)
                            src_batch_token_count = dataset.src_sizes[item]
                            dst_batch_token_count = dataset.tgt_sizes[item]
                            seq_counts.append(batch_n_seq)
                            batch_n_seq = 1
                    batches.append([])
                    batch_n_seq = 0
                    seq_counts.append(batch_n_seq)
                    src_batch_token_count = 0
                    dst_batch_token_count = 0
                    src_token_counts.append(src_batch_token_count)
                    dst_token_counts.append(dst_batch_token_count)

                seq_cnt2 = []
                for batch in batches:
                    seq_len_tracker = 0
                    nseqbucket = 0
                    for item in batch:
                        a = dataset.src_sizes[item]
                        b = dataset.tgt_sizes[item]
                        seq_len_tracker = max(seq_len_tracker, max(a, b))
                        nseqbucket += 1
                    longest_in_batch.append(seq_len_tracker)
                    seq_cnt2.append(nseqbucket)

                # In the unlucky case, remove a newly created but empty last batch
                if not batches[-1]:
                    del batches[-1]
                    del seq_counts[-1]
                    del src_token_counts[-1]
                    del dst_token_counts[-1]

                tmp_batches = batches
                batches = []
                for b in tmp_batches:
                    if b:
                        batches.append(b)
#padded_token_counts = src_token_counts
#padded_token_counts = [x*0 for x in src_token_counts] # Setting to zero until this is actually implemented
#print('split dataset length:', len(dataset.src))
#print('mean src tokens per batch =', statistics.mean(src_token_counts), statistics.mean(padded_token_counts))
#print('median src tokens per batch =', statistics.median(src_token_counts), statistics.median(padded_token_counts))
#print('stdev src tokens per batch =', statistics.stdev(src_token_counts), statistics.stdev(padded_token_counts))
#print('min src tokens per batch =', min(src_token_counts), min(padded_token_counts))
#print('max src tokens per batch =', max(src_token_counts), max(padded_token_counts))
#print('mean tgt tokens per batch =', statistics.mean(dst_token_counts), statistics.mean(padded_token_counts))
#print('median tgt tokens per batch =', statistics.median(dst_token_counts), statistics.mean(padded_token_counts))
#print('stdev tgt tokens per batch =', statistics.stdev(dst_token_counts), statistics.stdev(padded_token_counts))
#print('min tgt tokens per batch =', min(dst_token_counts), min(padded_token_counts))
#print('max tgt tokens per batch =', max(dst_token_counts), max(padded_token_counts))
#print('mean seq per batch =', statistics.mean(seq_counts), statistics.mean(padded_token_counts))
#print('median seq per batch =', statistics.median(seq_counts), statistics.median(padded_token_counts))
#print('stdev seq per batch =', statistics.stdev(seq_counts), statistics.stdev(padded_token_counts))
#print('min seq per batch =', min(seq_counts), min(padded_token_counts))
#print('max seq per batch =', max(seq_counts), max(padded_token_counts))
#print('pad inc: mean tgt tokens per batch =', statistics.mean(np.array(seq_cnt2) * np.array(longest_in_batch)), longest_in_batch[:3], seq_cnt2[:3])
#print('pad inc: median tgt tokens per batch =', statistics.median(np.array(seq_cnt2) * np.array(longest_in_batch)), longest_in_batch[:3], seq_cnt2[:3])
            self.frozen_batches = tuple(batches)
            # self.frozen_batches = tuple(self._batch_generator())
            print("generated %d batches in %fs" % (len(batches), time.time() - start))

    def __len__(self):
        return len(self.frozen_batches)

    def next_epoch_itr(self, shuffle=True):
        """Shuffle batches and return a new iterator over the dataset."""
        if self._next_epoch_itr is not None:
            self._cur_epoch_itr = self._next_epoch_itr
            self._next_epoch_itr = None
        else:
            self.epoch += 1
            self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
        return self._cur_epoch_itr

    def end_of_epoch(self):
        return not self._cur_epoch_itr.has_next()

    @property
    def iterations_in_epoch(self):
        if self._cur_epoch_itr is not None:
            return self._cur_epoch_itr.count
        elif self._next_epoch_itr is not None:
            return self._next_epoch_itr.count
        return 0

    def state_dict(self):
        return {
            'epoch': self.epoch,
            'iterations_in_epoch': self.iterations_in_epoch,
        }

    def load_state_dict(self, state_dict):
        self.epoch = state_dict['epoch']
        itr_pos = state_dict.get('iterations_in_epoch', 0)
        if itr_pos > 0:
            # fast-forward epoch iterator
            itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
            if itr_pos < len(itr):
                self._next_epoch_itr = itr.skip(itr_pos)

    def _get_iterator_for_epoch(self, epoch, shuffle):
        if shuffle:
            # set seed based on the seed and epoch number so that we get
            # reproducible results when resuming from checkpoints
            with numpy_seed(self.seeds[epoch]):
                batches = list(self.frozen_batches)  # copy
                np.random.shuffle(batches)
        else:
            batches = self.frozen_batches
        return CountingIterator(torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.dataset.collater,
            num_workers=self.dataloader_num_workers,
            pin_memory=self.dataloader_pin_memory,
            batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
        ))

    def _batch_generator(self):
        batch = []

        def is_batch_full(num_tokens):
            if len(batch) == 0:
                return False
            if len(batch) == self.max_sentences:
                return True
            if num_tokens > self.max_tokens:
                return True
            return False

        sample_len = 0
        sample_lens = []
        ignored = []
        for idx in self.dataset.ordered_indices(self.seeds[self.epoch]):
            if not self.dataset.valid_size(idx, self.max_positions):
                if self.ignore_invalid_inputs:
                    ignored.append(idx)
                    continue
                raise Exception((
                    'Size of sample #{} is invalid, max_positions={}, skip this example '
                    'with --skip-invalid-size-inputs-valid-test'
                ).format(idx, self.max_positions))

            sample_lens.append(self.dataset.num_tokens(idx))
            sample_len = max(sample_len, sample_lens[-1])
            num_tokens = (len(batch) + 1) * sample_len
            if is_batch_full(num_tokens):
                mod_len = max(
                    self.bsz_mult * (len(batch) // self.bsz_mult),
                    len(batch) % self.bsz_mult,
                )
                yield batch[:mod_len]
                batch = batch[mod_len:]
                sample_lens = sample_lens[mod_len:]
                sample_len = max(sample_lens) if len(sample_lens) > 0 else 0

            batch.append(idx)

        if len(batch) > 0:
            yield batch

        if len(ignored) > 0:
            print((
                '| WARNING: {} samples have invalid sizes and will be skipped, '
                'max_positions={}, first few sample ids={}'
            ).format(len(ignored), self.max_positions, ignored[:10]))


@contextlib.contextmanager
def numpy_seed(seed):
    """Context manager which seeds the NumPy PRNG with the specified seed and restores the state afterward"""
    if seed is None:
        yield
        return
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)
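
A rough usage sketch of collate_tokens above, assuming it is imported from fairseq.data.data_utils and that pad index 1 and eos index 2 are used, as in the MLPerf dictionary. With the default multiples, a batch of fewer than 8 sequences has its sequence length rounded up to a multiple of 8:

import torch
from fairseq.data.data_utils import collate_tokens

# two toy sequences, each ending in the eos index (2); pad index is 1
seqs = [torch.LongTensor([4, 5, 2]), torch.LongTensor([6, 7, 8, 9, 2])]

batch = collate_tokens(seqs, pad_idx=1, eos_idx=2, left_pad=False)
print(batch.shape)  # torch.Size([2, 8]) -- right-padded with 1s

shifted = collate_tokens(seqs, pad_idx=1, eos_idx=2, left_pad=False,
                         move_eos_to_beginning=True)
print(shifted[0])   # tensor([2, 4, 5, 1, 1, 1, 1, 1]) -- eos rotated to the front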
implementations/pytorch/fairseq/data/dictionary.py  (new file, 0 → 100644)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import Counter
import os

import torch


# MLPerf compliant dictionary
class Dictionary(object):
    """A mapping from symbols to consecutive integers"""

    def __init__(self, pad='<pad>_', eos='<EOS>_'):
        self.pad_word, self.eos_word = pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        # dictionary indexing starts at 1 for consistency with Lua
        # Commented out and hard-coded since pad and eos are in the dictionary files already
        self.add_symbol('<lua_index_compat>')
        self.pad_index = 1
        self.eos_index = 2
        #self.pad_index = self.add_symbol(pad)
        #self.eos_index = self.add_symbol(eos)
        #self.add_symbol('<bypass_unk>')
        self.nspecial = 3

    def __eq__(self, other):
        return self.indices == other.indices

    def __getitem__(self, idx):
        if idx < len(self.symbols):
            return self.symbols[idx]
        else:
            assert idx < len(self.symbols)

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def index(self, sym):
        """Returns the index of the specified symbol"""
        if sym in self.indices:
            return self.indices[sym]
        else:
            assert sym in self.indices

    def string(self, tensor, bpe_symbol=None):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return '\n'.join(self.string(t) for t in tensor)

        def token_string(i):
            return self[i]

        sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
        if bpe_symbol is not None:
            sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
        return sent

    def add_symbol(self, word, n=1):
        """Adds a word to the dictionary"""
        if word in self.indices:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def update(self, new_dict):
        """Updates counts from new dictionary."""
        for word in new_dict.symbols:
            idx2 = new_dict.indices[word]
            if word in self.indices:
                idx = self.indices[word]
                self.count[idx] = self.count[idx] + new_dict.count[idx2]
            else:
                idx = len(self.symbols)
                self.indices[word] = idx
                self.symbols.append(word)
                self.count.append(new_dict.count[idx2])

    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        """Sort symbols by frequency in descending order, ignoring special ones.

        Args:
            - threshold defines the minimum word count
            - nwords defines the total number of words in the final dictionary,
                including special symbols
            - padding_factor can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        if nwords <= 0:
            nwords = len(self)

        new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial)))
        new_symbols = self.symbols[:self.nspecial]
        new_count = self.count[:self.nspecial]

        c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:])))
        for symbol, count in c.most_common(nwords - self.nspecial):
            if count >= threshold:
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(count)
            else:
                break

        threshold_nwords = len(new_symbols)
        if padding_factor > 1:
            i = 0
            while threshold_nwords % padding_factor != 0:
                symbol = 'madeupword{:04d}'.format(i)
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(0)
                i += 1
                threshold_nwords += 1

        assert len(new_symbols) % padding_factor == 0
        assert len(new_symbols) == len(new_indices)

        self.count = list(new_count)
        self.symbols = list(new_symbols)
        self.indices = new_indices

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    @classmethod
    def load(cls, f, ignore_utf_errors=False):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0>
        <symbol1>
        ...
        ```
        """
        if isinstance(f, str):
            try:
                if not ignore_utf_errors:
                    with open(f, 'r', encoding='utf-8') as fd:
                        return cls.load(fd)
                else:
                    with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
                        return cls.load(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except Exception:
                raise Exception("Incorrect encoding detected in {}, please rebuild the dataset".format(f))

        d = cls()
        for line in f.readlines():
            word = line.strip()[1:-1]  ## Remove the single quotes
            count = 1
            d.indices[word] = len(d.symbols)
            d.symbols.append(word)
            d.count.append(count)

        n_pad_tokens_on_end = 33712 - len(d.symbols)
        #assert n_pad_tokens_on_end == 3 ## DEBUG: remove later, sanity check
        for i in range(n_pad_tokens_on_end):
            pad_str = '<pad000' + str(i) + '>'
            d.indices[pad_str] = len(d.symbols)
            d.symbols.append(pad_str)
            d.count.append(1)
        return d

    def save(self, f):
        """Stores dictionary into a text file"""
        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, 'w', encoding='utf-8') as fd:
                return self.save(fd)
        for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
            print('{} {}'.format(symbol, count), file=f)

    def dummy_sentence(self, length):
        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
        t[-1] = self.eos()
        return t


class Dictionary_fairseq(object):
    """A mapping from symbols to consecutive integers"""

    def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
        self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        # dictionary indexing starts at 1 for consistency with Lua
        self.add_symbol('<Lua heritage>')
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        self.nspecial = len(self.symbols)

    def __eq__(self, other):
        return self.indices == other.indices

    def __getitem__(self, idx):
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def index(self, sym):
        """Returns the index of the specified symbol"""
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def string(self, tensor, bpe_symbol=None, escape_unk=False):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return '\n'.join(self.string(t) for t in tensor)

        def token_string(i):
            if i == self.unk():
                return self.unk_string(escape_unk)
            else:
                return self[i]

        sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
        if bpe_symbol is not None:
            sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
        return sent

    def unk_string(self, escape=False):
        """Return unknown string, optionally escaped as: <<unk>>"""
        if escape:
            return '<{}>'.format(self.unk_word)
        else:
            return self.unk_word

    def add_symbol(self, word, n=1):
        """Adds a word to the dictionary"""
        if word in self.indices:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def update(self, new_dict):
        """Updates counts from new dictionary."""
        for word in new_dict.symbols:
            idx2 = new_dict.indices[word]
            if word in self.indices:
                idx = self.indices[word]
                self.count[idx] = self.count[idx] + new_dict.count[idx2]
            else:
                idx = len(self.symbols)
                self.indices[word] = idx
                self.symbols.append(word)
                self.count.append(new_dict.count[idx2])

    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        """Sort symbols by frequency in descending order, ignoring special ones.

        Args:
            - threshold defines the minimum word count
            - nwords defines the total number of words in the final dictionary,
                including special symbols
            - padding_factor can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        if nwords <= 0:
            nwords = len(self)

        new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial)))
        new_symbols = self.symbols[:self.nspecial]
        new_count = self.count[:self.nspecial]

        c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:])))
        for symbol, count in c.most_common(nwords - self.nspecial):
            if count >= threshold:
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(count)
            else:
                break

        threshold_nwords = len(new_symbols)
        if padding_factor > 1:
            i = 0
            while threshold_nwords % padding_factor != 0:
                symbol = 'madeupword{:04d}'.format(i)
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(0)
                i += 1
                threshold_nwords += 1

        assert len(new_symbols) % padding_factor == 0
        assert len(new_symbols) == len(new_indices)

        self.count = list(new_count)
        self.symbols = list(new_symbols)
        self.indices = new_indices

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index

    @classmethod
    def load(cls, f, ignore_utf_errors=False):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        if isinstance(f, str):
            try:
                if not ignore_utf_errors:
                    with open(f, 'r', encoding='utf-8') as fd:
                        return cls.load(fd)
                else:
                    with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
                        return cls.load(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except Exception:
                raise Exception("Incorrect encoding detected in {}, please rebuild the dataset".format(f))

        d = cls()
        for line in f.readlines():
            idx = line.rfind(' ')
            word = line[:idx]
            count = int(line[idx + 1:])
            d.indices[word] = len(d.symbols)
            d.symbols.append(word)
            d.count.append(count)
        return d

    def save(self, f):
        """Stores dictionary into a text file"""
        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, 'w', encoding='utf-8') as fd:
                return self.save(fd)
        for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
            print('{} {}'.format(symbol, count), file=f)

    def dummy_sentence(self, length):
        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
        t[-1] = self.eos()
        return t
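
A small sketch of the symbol/index bookkeeping above, assuming Dictionary_fairseq is imported from this module; the printed values follow from the constructor adding the Lua-compat, pad, eos, and unk entries first:

from fairseq.data.dictionary import Dictionary_fairseq

d = Dictionary_fairseq()          # symbols: '<Lua heritage>', '<pad>', '</s>', '<unk>'
d.add_symbol('hello')             # index 4
d.add_symbol('world')             # index 5
print(len(d), d.pad(), d.eos(), d.unk())      # 6 1 2 3
print(d.index('hello'), d.index('missing'))   # 4 3 -- unknown symbols map to unk_index
d.finalize(padding_factor=8)      # appends 'madeupword...' entries until the size is a multiple of 8
print(len(d))                     # 8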
implementations/pytorch/fairseq/data/fairseq_dataset.py  (new file, 0 → 100644)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.utils.data


class FairseqDataset(torch.utils.data.Dataset):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        raise NotImplementedError

    def get_dummy_batch(self, num_tokens, max_positions):
        """Return a dummy batch with a given number of tokens."""
        raise NotImplementedError

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        raise NotImplementedError

    def ordered_indices(self, seed=None):
        """Ordered indices for batching."""
        raise NotImplementedError

    def valid_size(self, index, max_positions):
        """Check if an example's size is valid according to max_positions."""
        raise NotImplementedError
implementations/pytorch/fairseq/data/indexed_dataset.py  (new file, 0 → 100644)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import struct

import numpy as np
import torch

from fairseq.tokenizer import Tokenizer


def read_longs(f, n):
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    f.write(np.array(a, dtype=np.int64))


dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float,
    7: np.double,
}


def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k


def index_file_path(prefix_path):
    return prefix_path + '.idx'


def data_file_path(prefix_path):
    return prefix_path + '.bin'


class IndexedDataset(torch.utils.data.Dataset):
    """Loader for TorchNet IndexedDataset"""

    def __init__(self, path):
        super().__init__()
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            assert magic == b'TNTIDX\x00\x00'
            version = f.read(8)
            assert struct.unpack('<Q', version) == (1,)
            code, self.element_size = struct.unpack('<QQ', f.read(16))
            self.dtype = dtypes[code]
            self.size, self.s = struct.unpack('<QQ', f.read(16))
            self.dim_offsets = read_longs(f, self.size + 1)
            self.data_offsets = read_longs(f, self.size + 1)
            self.sizes = read_longs(f, self.s)
        self.read_data(path)

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb', buffering=0)

    def check_index(self, i):
        if i < 0 or i >= self.size:
            raise IndexError('index out of range')

    def __del__(self):
        self.data_file.close()

    def __getitem__(self, i):
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        self.data_file.seek(self.data_offsets[i] * self.element_size)
        self.data_file.readinto(a)
        #a += 1 ## DEBUG: lua_index_compat
        item = torch.from_numpy(a).long()
        return item

    def __len__(self):
        return self.size

    @staticmethod
    def exists(path):
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )


class MockedInMemoryDataset(IndexedDataset):
    """Loader for TorchNet IndexedDataset, keeps all the data in memory"""

    def __init__(self, path, n_seq_pairs_in_mock_data, uniform_n_seq_per_batch, uniform_seq_len_per_batch):
        self.dtype = np.int64
        self.uniform_n_seq_per_batch = uniform_n_seq_per_batch
        self.uniform_seq_len_per_batch = uniform_seq_len_per_batch
        self.size = n_seq_pairs_in_mock_data
        self.sizes = []
        for i in range(n_seq_pairs_in_mock_data):
            self.sizes.append(uniform_seq_len_per_batch)

    def __del__(self):
        pass

    def __getitem__(self, i):
        self.check_index(i)
        arbitrary_token_id = 55  # Just not a reserved token
        a = np.ones((self.uniform_seq_len_per_batch,), dtype=self.dtype) * arbitrary_token_id
        a[-1] = 2  # Manually add an <EOS>
        #a[self.uniform_seq_len_per_batch-1] = 2 # Manually add an <EOS>
        return torch.from_numpy(a).long()


class IndexedInMemoryDataset(IndexedDataset):
    """Loader for TorchNet IndexedDataset, keeps all the data in memory"""

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb')
        self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
        self.data_file.readinto(self.buffer)
        #print('buffer max:', np.max(self.buffer), np.min(self.buffer))
        #self.buffer[self.buffer > 0] += 1 ## DEBUG
        #self.buffer += 1 ## DEBUG
        #print('buffer max after:', np.max(self.buffer), np.min(self.buffer))
        self.data_file.close()

    def __del__(self):
        pass

    def __getitem__(self, i):
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
        return torch.from_numpy(a).long()


class IndexedRawTextDataset(IndexedDataset):
    """Takes a text file as input and binarizes it in memory at instantiation.
    Original lines are also kept in memory"""

    def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
        self.tokens_list = []
        self.lines = []
        self.sizes = []
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.read_data(path, dictionary)
        self.size = len(self.tokens_list)

    def read_data(self, path, dictionary):
        with open(path, 'r') as f:
            for line in f:
                self.lines.append(line.strip('\n'))
                tokens = Tokenizer.tokenize(
                    line, dictionary, add_if_not_exist=False,
                    append_eos=self.append_eos, reverse_order=self.reverse_order,
                ).long()
                self.tokens_list.append(tokens)
                self.sizes.append(len(tokens))
        self.sizes = np.array(self.sizes)

    def __getitem__(self, i):
        self.check_index(i)
        return self.tokens_list[i]

    def get_original_text(self, i):
        self.check_index(i)
        return self.lines[i]

    def __del__(self):
        pass

    def __len__(self):
        return self.size

    @staticmethod
    def exists(path):
        return os.path.exists(path)


class IndexedRawTokenIDDataset(IndexedDataset):
    """Takes a text file containing token IDs (integers written in UTF-8 format) as input and binarizes it in memory at instantiation.
    Original lines are also kept in memory"""

    def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
        self.tokens_list = []
        self.lines = []
        self.sizes = []
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.read_data(path, dictionary)
        self.size = len(self.tokens_list)

    def read_data(self, path, dictionary):
        with open(path, 'r') as f:
            for line in f:
                if line != '\n':
                    self.lines.append(line.strip('\n'))
                    nwords = len(line.split(' '))
                    tokens = torch.IntTensor(nwords).long()
                    for idx, tok in enumerate(line.split(' ')):
                        tokens[idx] = int(tok)
                    #tokens = line.split(' ')
                    self.tokens_list.append(tokens)
                    self.sizes.append(len(tokens))
        self.sizes = np.array(self.sizes)

    def __getitem__(self, i):
        self.check_index(i)
        return self.tokens_list[i]

    def get_original_text(self, i):
        self.check_index(i)
        return self.lines[i]

    def __del__(self):
        pass

    def __len__(self):
        return self.size

    @staticmethod
    def exists(path):
        return os.path.exists(path)


class IndexedDatasetBuilder(object):
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float: 4,
        np.double: 8,
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, 'wb')
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]

    def add_item(self, tensor):
        bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, 'wb')
        index.write(b'TNTIDX\x00\x00')
        index.write(struct.pack('<Q', 1))
        index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
        index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        index.close()
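
A standalone sketch of the index-file layout that IndexedDataset.__init__ above expects; the field order mirrors what IndexedDatasetBuilder.finalize writes, and the path argument is hypothetical:

import struct
import numpy as np

def read_index_header(idx_path):
    # Parse a TNTIDX .idx file the same way IndexedDataset.__init__ does
    with open(idx_path, 'rb') as f:
        assert f.read(8) == b'TNTIDX\x00\x00'            # magic
        (version,) = struct.unpack('<Q', f.read(8))       # format version, always 1
        dtype_code, element_size = struct.unpack('<QQ', f.read(16))
        n_items, n_sizes = struct.unpack('<QQ', f.read(16))
        dim_offsets = np.fromfile(f, dtype=np.int64, count=n_items + 1)
        data_offsets = np.fromfile(f, dtype=np.int64, count=n_items + 1)
        sizes = np.fromfile(f, dtype=np.int64, count=n_sizes)
    return dtype_code, element_size, dim_offsets, data_offsets, sizes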
implementations/pytorch/fairseq/data/language_pair_dataset.py  (new file, 0 → 100644)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch

from . import data_utils, FairseqDataset


def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
            bsz_mult=8, seq_len_multiple=1):
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx, eos_idx, left_pad, move_eos_to_beginning,
            bsz_mult, seq_len_multiple)

    id = torch.LongTensor([s['id'] for s in samples])
    src_tokens = merge('source', left_pad=left_pad_source)
    # sort by descending source length
    src_lengths = torch.LongTensor([s['source'].numel() for s in samples])

    prev_output_tokens = None
    target = None
    if samples[0].get('target', None) is not None:
        target = merge('target', left_pad=left_pad_target)
        # we create a shifted version of targets for feeding the
        # previous output token(s) into the next decoder step
        prev_output_tokens = merge(
            'target',
            left_pad=left_pad_target,
            move_eos_to_beginning=True,
        )
        ntokens = sum(len(s['target']) for s in samples)
    else:
        ntokens = sum(len(s['source']) for s in samples)

    return {
        'id': id,
        'ntokens': ntokens,
        'net_input': {
            'src_tokens': src_tokens,
            'src_lengths': src_lengths,
            'prev_output_tokens': prev_output_tokens,
        },
        'target': target,
    }


class LanguagePairDataset(FairseqDataset):
    """A pair of torch.utils.data.Datasets."""

    def __init__(self, src, src_sizes, src_dict, tgt=None, tgt_sizes=None, tgt_dict=None,
                 left_pad_source=True, left_pad_target=False,
                 max_source_positions=256, max_target_positions=256,
                 seq_len_multiple=1, shuffle=True):
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.seq_len_multiple = seq_len_multiple
        self.shuffle = shuffle
        print("| Sentences are being padded to multiples of: {}".format(self.seq_len_multiple))

    def __getitem__(self, index):
        return {
            'id': index,
            'source': self.src[index],
            'target': self.tgt[index] if self.tgt is not None else None,
        }

    def __len__(self):
        return len(self.src)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        return collate(
            samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
            left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
            bsz_mult=8, seq_len_multiple=self.seq_len_multiple,
        )

    def get_dummy_batch(self, max_tokens_per_batch, max_positions, src_len=256, tgt_len=256):
        max_source_positions, max_target_positions = self._get_max_positions(max_positions)
        src_len, tgt_len = min(src_len, max_source_positions), min(tgt_len, max_target_positions)
        n_seq_per_batch_based_on_longest_seq = max_tokens_per_batch // max(src_len, tgt_len)
        return self.collater([
            {
                'id': i,
                'source': self.src_dict.dummy_sentence(src_len),
                'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None,
            }
            for i in range(n_seq_per_batch_based_on_longest_seq)
        ])

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching.

        Args:
            index: points to the sequence pair
        """
        n_tok_per_seq = max(self.src_sizes[index],
                            self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
        assert self.seq_len_multiple > 0, "Padding multiple has to be greater than 0"
        # Padded seq len, rounded up to next multiple
        n_tok_per_seq = (n_tok_per_seq + self.seq_len_multiple - 1) // self.seq_len_multiple * self.seq_len_multiple
        return n_tok_per_seq

    def ordered_indices(self, seed=None):
        """Ordered indices for batching."""
        if self.shuffle:
            indices = np.random.RandomState(seed).permutation(len(self))
        else:
            indices = np.arange(len(self))
        if self.tgt_sizes is not None:
            indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
        return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]

    def valid_size(self, index, max_positions):
        """Check if an example's size is valid according to max_positions."""
        max_source_positions, max_target_positions = self._get_max_positions(max_positions)
        return (
            self.src_sizes[index] <= max_source_positions
            and (self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
        )

    def _get_max_positions(self, max_positions):
        if max_positions is None:
            return self.max_source_positions, self.max_target_positions
        assert len(max_positions) == 2
        max_src_pos, max_tgt_pos = max_positions
        return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)


def collater_isolated(samples, seq_len_multiple, left_pad_source, left_pad_target):
    """Merge a list of samples to form a mini-batch."""
    return collate(
        samples, pad_idx=1, eos_idx=2,
        left_pad_source=left_pad_source, left_pad_target=left_pad_target,
        bsz_mult=8, seq_len_multiple=seq_len_multiple,
    )


def get_dummy_batch_isolated(max_tokens_per_batch, max_positions, seq_len_multiple):
    '''Creates a dummy batch'''
    max_source_positions, max_target_positions = max_positions[0], max_positions[1]
    src_len, tgt_len = max_source_positions, max_target_positions
    n_seq_per_batch_based_on_longest_seq = max_tokens_per_batch // max(src_len, tgt_len)

    nspecial = 3
    ntok_alloc = 33712
    eos_id = 2
    dummy_seq_src = torch.Tensor(src_len).uniform_(nspecial + 1, ntok_alloc).long()
    dummy_seq_src[-1] = eos_id
    dummy_seq_tgt = torch.Tensor(tgt_len).uniform_(nspecial + 1, ntok_alloc).long()
    dummy_seq_tgt[-1] = eos_id

    return collater_isolated(
        [
            {'id': i, 'source': dummy_seq_src, 'target': dummy_seq_tgt}
            for i in range(n_seq_per_batch_based_on_longest_seq)
        ],
        seq_len_multiple,
        left_pad_source=True,
        left_pad_target=False,
    )
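
A small sketch of the batch dictionary produced by collate above, assuming it is in scope along with torch and using pad index 1 and eos index 2 as collater_isolated does:

import torch

samples = [
    {'id': 0, 'source': torch.LongTensor([5, 6, 2]), 'target': torch.LongTensor([7, 2])},
    {'id': 1, 'source': torch.LongTensor([8, 2]),    'target': torch.LongTensor([9, 10, 2])},
]
batch = collate(samples, pad_idx=1, eos_idx=2)
print(batch['net_input']['src_tokens'].shape)        # torch.Size([2, 8]) -- left-padded sources
print(batch['net_input']['prev_output_tokens'][1])   # tensor([ 2,  9, 10,  1,  1,  1,  1,  1])
print(batch['ntokens'])                              # 5 -- total target tokens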
implementations/pytorch/fairseq/data/monolingual_dataset.py  (new file, 0 → 100644)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch

from . import data_utils, FairseqDataset


def collate(samples, pad_idx, eos_idx):
    if len(samples) == 0:
        return {}

    def merge(key):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx, eos_idx, left_pad=False,
        )

    return {
        'id': torch.LongTensor([s['id'] for s in samples]),
        'ntokens': sum(len(s['target']) for s in samples),
        'net_input': {
            'src_tokens': merge('source'),
        },
        'target': merge('target'),
    }


class MonolingualDataset(FairseqDataset):
    """A wrapper around torch.utils.data.Dataset for monolingual data."""

    def __init__(self, dataset, sizes, vocab, shuffle):
        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = vocab
        self.shuffle = shuffle

    def __getitem__(self, index):
        source, target = self.dataset[index]
        return {'id': index, 'source': source, 'target': target}

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        return collate(samples, self.vocab.pad(), self.vocab.eos())

    def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
        assert isinstance(max_positions, float) or isinstance(max_positions, int)
        tgt_len = min(tgt_len, max_positions)
        bsz = num_tokens // tgt_len
        target = self.vocab.dummy_sentence(tgt_len + 1)
        source, target = target[:-1], target[1:]
        return self.collater([
            {'id': i, 'source': source, 'target': target}
            for i in range(bsz)
        ])

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        source, target = self.dataset[index]
        return len(source)

    def ordered_indices(self, seed=None):
        """Ordered indices for batching."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        order.append(np.flip(self.sizes, 0))
        return np.lexsort(order)

    def valid_size(self, index, max_positions):
        """Check if an example's size is valid according to max_positions."""
        assert isinstance(max_positions, float) or isinstance(max_positions, int)
        return self.sizes[index] <= max_positions
implementations/pytorch/fairseq/data/token_block_dataset.py  (new file, 0 → 100644)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math

import numpy as np
import torch


class TokenBlockDataset(torch.utils.data.Dataset):
    """Break a 1d tensor of tokens into blocks.

    The blocks are fetched from the original tensor so no additional memory is allocated.

    Args:
        tokens: 1d tensor of tokens to break into blocks
        sizes: sentence lengths (required for 'complete' and 'eos')
        block_size: maximum block size (ignored in 'eos' break mode)
        break_mode: Mode used for breaking tokens. Values can be one of:
            - 'none': break tokens into equally sized blocks (up to block_size)
            - 'complete': break tokens into blocks (up to block_size) such that
                blocks contains complete sentences, although block_size may be
                exceeded if some sentences exceed block_size
            - 'eos': each block contains one sentence (block_size is ignored)
        include_targets: return next tokens as targets
    """

    def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
        super().__init__()

        self.tokens = tokens
        self.total_size = len(tokens)
        self.include_targets = include_targets
        self.slice_indices = []

        if break_mode is None or break_mode == 'none':
            length = math.ceil(len(tokens) / block_size)

            def block_at(i):
                start = i * block_size
                end = min(start + block_size, len(tokens))
                return (start, end)

            self.slice_indices = [block_at(i) for i in range(length)]
        elif break_mode == 'complete':
            assert sizes is not None and sum(sizes) == len(tokens), \
                '{} != {}'.format(sum(sizes), len(tokens))
            tok_idx = 0
            sz_idx = 0
            curr_size = 0
            while sz_idx < len(sizes):
                if curr_size + sizes[sz_idx] <= block_size or curr_size == 0:
                    curr_size += sizes[sz_idx]
                    sz_idx += 1
                else:
                    self.slice_indices.append((tok_idx, tok_idx + curr_size))
                    tok_idx += curr_size
                    curr_size = 0
            if curr_size > 0:
                self.slice_indices.append((tok_idx, tok_idx + curr_size))
        elif break_mode == 'eos':
            assert sizes is not None and sum(sizes) == len(tokens), \
                '{} != {}'.format(sum(sizes), len(tokens))
            curr = 0
            for sz in sizes:
                # skip samples with just 1 example (which would be just the eos token)
                if sz > 1:
                    self.slice_indices.append((curr, curr + sz))
                curr += sz
        else:
            raise ValueError('Invalid break_mode: ' + break_mode)

        self.sizes = np.array([e - s for s, e in self.slice_indices])

    def __getitem__(self, index):
        s, e = self.slice_indices[index]
        item = torch.LongTensor(self.tokens[s:e])
        if self.include_targets:
            # target is the sentence, for source, rotate item one token to the left (would start with eos)
            if s == 0:
                source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
            else:
                source = self.tokens[s - 1:e - 1]
            return torch.LongTensor(source), item
        return item

    def __len__(self):
        return len(self.slice_indices)
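
A short sketch of the two most common break modes, assuming TokenBlockDataset above is in scope and using a hypothetical 9-token stream of three sentences that each end in eos id 2:

tokens = [10, 11, 2, 12, 13, 14, 2, 15, 2]
sizes = [3, 4, 2]

by_sentence = TokenBlockDataset(tokens, sizes, block_size=0, break_mode='eos')
print(by_sentence.slice_indices)   # [(0, 3), (3, 7), (7, 9)] -- one sentence per block

fixed = TokenBlockDataset(tokens, sizes, block_size=4, break_mode='none')
print(fixed.slice_indices)         # [(0, 4), (4, 8), (8, 9)] -- equal blocks of up to 4 tokens
print(fixed[0])                    # tensor([10, 11,  2, 12])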
implementations/pytorch/fairseq/distributed_utils.py  (new file, 0 → 100644)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import pickle

import torch.distributed

from fairseq import utils


def is_master(args):
    return args.distributed_rank == 0


def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError('Cannot initialize distributed with distributed_world_size=1')

    print('| distributed init (rank {}): {}'.format(
        args.distributed_rank, args.distributed_init_method), flush=True)
    if args.distributed_init_method.startswith('tcp://'):
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            init_method=args.distributed_init_method,
            world_size=args.distributed_world_size,
            rank=args.distributed_rank)
    elif args.distributed_init_method.startswith('env://'):
        import os
        print("| distributed env init. MASTER_ADDR: " + os.environ['MASTER_ADDR'] +
              ", MASTER_PORT: " + os.environ['MASTER_PORT'] +
              ", WORLD_SIZE: " + os.environ['WORLD_SIZE'] +
              ", RANK: " + os.environ['RANK'], flush=True)
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            init_method=args.distributed_init_method)
        print("| distributed init done!", flush=True)
        args.distributed_world_size = int(os.environ['WORLD_SIZE'])
    else:
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            init_method=args.distributed_init_method,
            world_size=args.distributed_world_size)

    args.distributed_rank = torch.distributed.get_rank()
    if not is_master(args):
        suppress_output()

    return args.distributed_rank


def suppress_output():
    """Suppress printing on the current device. Force printing with `force=True`."""
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        if 'force' in kwargs:
            force = kwargs.pop('force')
            if force:
                builtin_print(*args, **kwargs)

    __builtin__.print = print


def all_gather_list(data, max_size=16384):
    """Gathers arbitrary data from all nodes into a list."""
    world_size = torch.distributed.get_world_size()
    if not hasattr(all_gather_list, '_in_buffer') or \
            max_size != len(all_gather_list._in_buffer):
        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
        all_gather_list._out_buffers = [
            torch.cuda.ByteTensor(max_size)
            for i in range(world_size)
        ]
    in_buffer = all_gather_list._in_buffer
    out_buffers = all_gather_list._out_buffers

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
    assert max_size < 255 * 256
    in_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
    in_buffer[1] = enc_size % 255
    in_buffer[2:enc_size + 2] = torch.ByteTensor(list(enc))

    torch.distributed.all_gather(out_buffers, in_buffer.cuda())

    result = []
    for i in range(world_size):
        out_buffer = out_buffers[i]
        size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
        result.append(
            pickle.loads(bytes(out_buffer[2:size + 2].tolist()))
        )
    return result
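A minimal sketch (not part of the commit) of the 2-byte length header that all_gather_list above packs in front of the pickled payload: the length is stored as (size // 255, size % 255) and recovered as 255 * b0 + b1, which is why max_size must stay below 255 * 256. The pack/unpack helpers here are hypothetical names used only to show the arithmetic without GPUs or a process group.

import pickle

def pack(data, max_size=16384):
    enc = pickle.dumps(data)
    enc_size = len(enc)
    assert enc_size + 2 <= max_size and max_size < 255 * 256
    return bytes([enc_size // 255, enc_size % 255]) + enc

def unpack(buf):
    size = 255 * buf[0] + buf[1]
    return pickle.loads(bytes(buf[2:size + 2]))

assert unpack(pack({'rank': 0, 'loss': 1.25})) == {'rank': 0, 'loss': 1.25}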
implementations/pytorch/fairseq/fp16_trainer.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Train a network on multiple GPUs.
"""

import torch
import ctypes

from fairseq import optim, utils
from fairseq.meters import AverageMeter
from fairseq.optim import lr_scheduler
from fairseq.trainer import Trainer


def fused_norm(input):
    return input.norm(dtype=torch.float32, p=2).item()


class DynamicLossScaler:

    def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self._iter = 0
        self._last_overflow_iter = -1

    def update_scale(self, overflow):
        if overflow:
            self.loss_scale /= self.scale_factor
            self._last_overflow_iter = self._iter
        elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
        self._iter += 1

    @staticmethod
    def has_overflow(grad_norm):
        # detect inf and nan
        if grad_norm == float('inf') or grad_norm != grad_norm:
            return True
        return False
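A minimal sketch (not part of the commit) of how the DynamicLossScaler above behaves: the scale is divided by scale_factor on an overflowing step and multiplied by scale_factor again after scale_window consecutive clean steps; the small scale_window here is chosen only to keep the example short.

scaler = DynamicLossScaler(init_scale=2.**7, scale_factor=2., scale_window=4)

scaler.update_scale(overflow=True)        # inf/nan detected: 128 -> 64
assert scaler.loss_scale == 64.0

for _ in range(4):                        # four clean steps since the overflow
    scaler.update_scale(overflow=False)
assert scaler.loss_scale == 128.0         # scale is raised again

assert DynamicLossScaler.has_overflow(float('inf'))
assert DynamicLossScaler.has_overflow(float('nan'))
assert not DynamicLossScaler.has_overflow(3.5)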
class FP16Trainer(Trainer):
    """Modified trainer for FP16.

    We maintain two copies of the model's parameters, both in FP16 and FP32.
    We do forward/backward with FP16 and compute the loss + optimize with FP32.
    """

    def __init__(self, args, task, model, criterion, allreduce_communicators=None):
        super().__init__(args, task, model, criterion, allreduce_communicators)

        # convert model to FP16 (but keep criterion FP32)
        self.model.half()

        # broadcast initial weights from rank=0
        # this broadcast isn't required in DistributedFP16Trainer because the
        # broadcast is done by DistributedFusedAdam
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            for p in self.model.parameters():
                torch.distributed.broadcast(p, 0)

        # dynamically scale loss to reduce overflow
        self.scaler = DynamicLossScaler(init_scale=2.**7)
        self.meters['loss_scale'] = AverageMeter()

        self.grad_denom = 1.0

        if self.args.enable_parallel_backward_allred_opt:
            import numpy as np
            self._flat_grads_parallel = torch.tensor([], dtype=torch.float16).cuda()
            self._grads_info = []
            grads_size = 0
            p_offset = 0
            for p_i, p in enumerate([p for p in self.model.parameters() if p.requires_grad]):
                p_grads_size = np.prod(list(p.size()))
                grads_size += p_grads_size

                # register hooks
                def wrapper(param, param_i, param_grads_size, param_offset):
                    def allreduce_hook(grad):
                        self._do_allreduce(param_i, param_grads_size, param_offset, grad)

                    if param.requires_grad:
                        param.register_hook(allreduce_hook)

                # print(p_i, p.size(), p_grads_size, p_offset)
                self._grads_info.append({"param_grads_size": p_grads_size, "param_offset": p_offset})
                wrapper(p, p_i, p_grads_size, p_offset)
                p_offset += p_grads_size
            self._flat_grads_parallel.resize_(grads_size)
            # print(grads_size, len(self._flat_grads_parallel), self._flat_grads_parallel.dtype, self._flat_grads_parallel.get_device())

            self._allreduce_flush_min_threshold = self.args.parallel_backward_allred_opt_threshold
            print("| parallel all-reduce ENABLED. all-reduce threshold: " + str(self._allreduce_flush_min_threshold))
            self._grads_generated = [False] * len(self._grads_info)
            self._allreduce_processed_idx = len(self._grads_info) - 1
            self._num_allreduce_sent = 0

            print("| # of parallel all-reduce cuda streams: " + str(self.args.parallel_backward_allred_cuda_nstreams))
            if allreduce_communicators:
                self._allreduce_groups = allreduce_communicators[0]
                self._allreduce_streams = allreduce_communicators[1]
            else:
                raise RuntimeError('Moved communicator init before RUN_START (invalid code path)')
                self._allreduce_groups = [torch.distributed.new_group() for _ in range(self.args.parallel_backward_allred_cuda_nstreams)]
                self._allreduce_streams = [torch.cuda.Stream() for _ in range(self.args.parallel_backward_allred_cuda_nstreams)]

            if self.args.enable_parallel_backward_allred_opt_correctness_check:
                self._num_grads_generated = 0
                self._all_grads_generated = False
                self._allreduce_schedule = []

    def _get_flush_bucket(self):
        # print([1 if x else 0 for x in self._grads_generated])
        flush_bucket = []
        size = 0
        allreduce_processed_idx_list = []
        allreduce_processed_end_idx = self._allreduce_processed_idx
        remaining_grads_for_allreduce = self._grads_generated[allreduce_processed_end_idx - len(self._grads_generated)::-1]
        # print([1 if x else 0 for x in remaining_grads_for_allreduce])
        for s in remaining_grads_for_allreduce:
            # print(s, allreduce_processed_end_idx, size)
            if s:
                allreduce_processed_idx_list.append(allreduce_processed_end_idx)
                size += self._grads_info[allreduce_processed_end_idx]["param_grads_size"]
                allreduce_processed_end_idx -= 1
            else:
                break
        # print(size, allreduce_processed_idx_list)

        ignore_threshold = all(self._grads_generated)
        if size >= self._allreduce_flush_min_threshold or ignore_threshold:
            # for i in allreduce_processed_idx_list:
            #     print(i, self._grads_info[i]["param_grads_size"], self._grads_info[i]["param_offset"], size)
            if allreduce_processed_idx_list:
                start = self._grads_info[(allreduce_processed_idx_list[-1])]["param_offset"]
                end = start + size
                # print("->", start, end)
                flush_bucket = [start, end]
            self._allreduce_processed_idx = allreduce_processed_end_idx
            if self._allreduce_processed_idx < 0:
                # reset
                self._grads_generated = [False] * len(self._grads_info)
                self._allreduce_processed_idx = len(self._grads_info) - 1
        return flush_bucket

    def _do_allreduce(self, param_i, param_grads_size, param_offset, grad):
        if self._last_step == False:
            # # ----------------------
            # # debugging: do all-reduce in the same stream
            # print(self._last_step, self._grads_total, len(self._backward_grads_schedule), param_i, param_offset, param_grads_size, grad.size(), grad.numel(), grad.dtype)
            # self._flat_grads_parallel[param_offset:param_offset+param_grads_size].copy_(grad.view(-1))
            # self._flat_grads_parallel[param_offset:param_offset+param_grads_size].div_(self.args.distributed_world_size)
            # torch.distributed.all_reduce(self._flat_grads_parallel[param_offset:param_offset+param_grads_size])
            # # ----------------------
            # # ----------------------
            # # option #1: send per-layer gradients
            # torch.div(grad.view(-1), self.args.distributed_world_size, out=self._flat_grads_parallel[param_offset:param_offset+param_grads_size])
            # orig_stream = torch.cuda.current_stream()
            # self._reduction_stream.wait_stream(orig_stream)
            # with torch.cuda.stream(self._reduction_stream):
            #     torch.distributed.all_reduce(self._flat_grads_parallel[param_offset:param_offset+param_grads_size])
            # # ----------------------
            # ----------------------
            # option #2: bucket all-reduce based on threshold
            self._flat_grads_parallel.record_stream(torch.cuda.current_stream())
            torch.div(grad.view(-1), self.args.distributed_world_size,
                      out=self._flat_grads_parallel[param_offset:param_offset + param_grads_size])
            self._grads_generated[param_i] = True
            flush_bucket = self._get_flush_bucket()
            if flush_bucket:
                start = flush_bucket[0]
                end = flush_bucket[1]
                # print("->", start, end)
                if self.args.enable_parallel_backward_allred_opt_correctness_check and not self._all_grads_generated:
                    self._allreduce_schedule.append(flush_bucket)

                # orig_stream = torch.cuda.current_stream()
                # self._reduction_stream.wait_stream(orig_stream)
                # with torch.cuda.stream(self._reduction_stream):
                #     torch.distributed.all_reduce(self._flat_grads_parallel[start:end])
                orig_stream = torch.cuda.current_stream()
                allreduce_group = self._allreduce_groups[self._num_allreduce_sent % len(self._allreduce_groups)]
                allreduce_stream = self._allreduce_streams[self._num_allreduce_sent % len(self._allreduce_streams)]
                allreduce_stream.wait_stream(orig_stream)
                with torch.cuda.stream(allreduce_stream):
                    self._flat_grads_parallel.record_stream(torch.cuda.current_stream())
                    torch.distributed.all_reduce(self._flat_grads_parallel[start:end], group=allreduce_group)
                self._num_allreduce_sent += 1

            if self.args.enable_parallel_backward_allred_opt_correctness_check:
                self._num_grads_generated += 1
                if self._num_grads_generated == len(self._grads_info):
                    self._all_grads_generated = True
            # ----------------------

    def _build_optimizer(self):
        # create FP32 copy of parameters and grads
        params = [p for p in self.model.parameters() if p.requires_grad]
        total_param_size = sum(p.data.numel() for p in params)
        self.fp32_params = params[0].new(0).float().new(total_param_size)
        offset = 0
        for p in params:
            numel = p.data.numel()
            self.fp32_params[offset:offset + numel].copy_(p.data.view(-1))
            offset += numel
        self.fp32_params = torch.nn.Parameter(self.fp32_params)
        # self.fp32_params.grad = self.fp32_params.data.new(total_param_size)

        # create optimizer using the copied FP32 params
        self._optimizer = optim.build_optimizer(self.args, [self.fp32_params])
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        extra_state['loss_scale'] = self.scaler.loss_scale
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(self, filename):
        """Load all training state from a checkpoint file."""
        extra_state = super().load_checkpoint(filename)
        if extra_state is not None and 'loss_scale' in extra_state:
            self.scaler.loss_scale = extra_state['loss_scale']
        return extra_state

    def zero_grad(self):
        # zero both the FP16 and FP32 grads
        # self.model.zero_grad()      # FP16
        # self.optimizer.zero_grad()  # FP32
        # r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
        for p in self.model.parameters():
            p.grad = None

    def _backward(self, loss):
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.scaler.loss_scale)
        if loss is not None:
            # dynamically rescale loss to stay in FP16 range
            loss = loss * self.scaler.loss_scale
        return super()._backward(loss)

    def _all_reduce_and_rescale(self, grad_denom, has_grad=True):
        # undo effect of dynamic loss scaling on gradients
        self.grad_denom = grad_denom * self.scaler.loss_scale

        if self.args.distributed_world_size > 1:
            self.grad_denom /= self.args.distributed_world_size
            if not self.args.enable_parallel_backward_allred_opt or self._last_step:
                # flatten grads into a single buffer
                self._flat_grads = self._get_flat_grads(out=None, has_grad=has_grad)

                # scale gradients to avoid overflow in all-reduce
                self._flat_grads.div_(self.args.distributed_world_size)

                # all-reduce flat grads
                torch.distributed.all_reduce(self._flat_grads)
            else:
                # torch.cuda.current_stream().wait_stream(self._reduction_stream)
                for allreduce_stream in self._allreduce_streams:
                    torch.cuda.current_stream().wait_stream(allreduce_stream)
                self._flat_grads_parallel.record_stream(torch.cuda.current_stream())
                self._flat_grads = self._flat_grads_parallel

                if self.args.enable_parallel_backward_allred_opt_correctness_check:
                    # # ----------------------
                    # # option #1: send per-layer gradients
                    # grads = self._get_grads()
                    # offset = 0
                    # for g in grads:
                    #     numel = g.numel()
                    #     out = grads[0].new(numel).zero_()
                    #     out.copy_(g.view(-1))
                    #     out.div_(self.args.distributed_world_size)
                    #     torch.distributed.all_reduce(out)
                    #     is_parallel_grads_finite = torch.all(torch.isfinite(self._flat_grads_parallel[offset:offset+numel]))
                    #     is_out_finite = torch.all(torch.isfinite(out))
                    #     assert(is_out_finite == is_parallel_grads_finite)
                    #     if not is_out_finite:
                    #         print("| OVERLAP-CHECK: check inf/nan detected. this batch should be skipped")
                    #     else:
                    #         if not torch.all(torch.eq(out, self._flat_grads_parallel[offset:offset+numel])):
                    #             print(out[0:10], self._flat_grads_parallel[offset:offset+10])
                    #             # for i,_ in enumerate(out):
                    #             #     if out[i] != self._flat_grads_parallel[i]:
                    #             #         print(i, out[i], self._flat_grads_parallel[i])
                    #             raise RuntimeError('w-gradients received in parallel vs. end differ')
                    #     offset += numel
                    # # ----------------------
                    # ----------------------
                    # option #2: bucket all-reduce based on threshold
                    # print(self._allreduce_schedule)
                    out = self._get_flat_grads()
                    out.div_(self.args.distributed_world_size)
                    grads_size = 0
                    for s in self._allreduce_schedule:
                        start = s[0]
                        end = s[1]
                        assert (end > start)
                        grads_size += (end - start)
                        torch.distributed.all_reduce(out[start:end])
                        is_parallel_grads_finite = torch.all(torch.isfinite(self._flat_grads_parallel[start:end]))
                        is_out_finite = torch.all(torch.isfinite(out[start:end]))
                        assert (is_out_finite == is_parallel_grads_finite)
                        if not is_out_finite:
                            print("| OVERLAP-CHECK: check inf/nan detected. this batch should be skipped")
                        else:
                            if not torch.all(torch.eq(out[start:end], self._flat_grads_parallel[start:end])):
                                print(start, end, out[start:end], self._flat_grads_parallel[start:end])
                                raise RuntimeError('w-gradients received in parallel vs. end differ')
                    assert (grads_size == len(self._flat_grads_parallel))
                    # ----------------------
        else:
            # flatten grads into a single buffer
            self._flat_grads = self._get_flat_grads(out=None, has_grad=has_grad)

        grad_norm = fused_norm(self._flat_grads)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            if self.scaler.loss_scale <= self.args.min_loss_scale:
                raise Exception((
                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
                    'Try lowering the learning rate, using gradient clipping or '
                    'increasing the batch size.'
                ).format(self.args.min_loss_scale))
            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

        return grad_norm

    def _opt(self):
        # take an optimization step using the FP32 params and grads
        # super()._opt()
        new_params = self._flat_grads.new_empty(self._flat_grads.size())
        self.optimizer.optimizer.step(closure=None, grads=[self._flat_grads],
                                      output_params=[new_params], scale=self.grad_denom)
        self.zero_grad()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

        # copy FP32 params back into FP16 model
        offset = 0
        with torch.no_grad():
            for p in self.model.parameters():
                if not p.requires_grad:
                    continue
                numel = p.data.numel()
                p.set_(new_params[offset:offset + numel].view_as(p.data))
                offset += numel


class DistributedFP16Trainer(Trainer):
    """Modified trainer for FP16.

    We maintain two copies of the model's parameters, both in FP16 and FP32.
    We do forward/backward with FP16 and compute the loss + optimize with FP32.
    """

    def __init__(self, args, task, model, criterion, allreduce_communicators=None):
        super().__init__(args, task, model, criterion, allreduce_communicators)

        # convert model to FP16 (but keep criterion FP32)
        self.model.half()

        # dynamically scale loss to reduce overflow
        self.scaler = DynamicLossScaler(init_scale=2.**7)
        self.meters['loss_scale'] = AverageMeter()
        # FIXME: Add more meters

        self.grad_denom = 1.0

        assert (not self.args.enable_parallel_backward_allred_opt), \
            "--distributed-weight-update cannot be combined with --enable-parallel-backward-allred-opt"

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        # To-Do: gather optimizer buffer chunks before saving state
        extra_state['loss_scale'] = self.scaler.loss_scale
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(self, filename):
        """Load all training state from a checkpoint file."""
        # To-Do: scatter optimizer buffer chunks after restoring state
        extra_state = super().load_checkpoint(filename)
        if extra_state is not None and 'loss_scale' in extra_state:
            self.scaler.loss_scale = extra_state['loss_scale']
        return extra_state

    # def zero_grad(self):
    #     for p in self.model.parameters():
    #         p.grad = None

    def _backward(self, loss):
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.scaler.loss_scale)
        if loss is not None:
            # dynamically rescale loss to stay in FP16 range
            loss = loss * self.scaler.loss_scale
        rval = super()._backward(loss)
        self.optimizer.optimizer.complete_reductions()
        return rval

    def __process_overflow(self, overflow):
        self.scaler.update_scale(overflow)
        if overflow:
            if self.scaler.loss_scale <= self.args.min_loss_scale:
                raise Exception((
                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
                    'Try lowering the learning rate, using gradient clipping or '
                    'increasing the batch size.'
                ).format(self.args.min_loss_scale))
            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

    def _all_reduce_and_rescale(self, grad_denom, has_grad=True):
        grad_norm = self.optimizer.optimizer.L2_grad_norm
        if grad_norm is not None:
            overflow = self.scaler.has_overflow(grad_norm)
            self.__process_overflow(overflow)
            return grad_norm
        else:
            return None

    def _opt(self):
        self.optimizer.optimizer.step(skip_overflow_check=self.args.dwu_compute_L2_grad_norm)
        self.zero_grad()
        self.__process_overflow(
            False if self.args.dwu_compute_L2_grad_norm or not self.optimizer.optimizer.has_overflow else True)
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)
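A minimal standalone sketch (not part of the commit) of the master-weights pattern that FP16Trainer._build_optimizer and _opt implement: the model holds FP16 parameters, a flat FP32 copy is what the optimizer updates, and the result is copied back into the FP16 model each step. The tiny Linear layer and SGD optimizer here are placeholders for illustration only.

import torch

model = torch.nn.Linear(4, 4).half()
params = [p for p in model.parameters() if p.requires_grad]

# flatten the FP16 params into one FP32 master buffer and optimize that buffer
fp32_params = torch.nn.Parameter(torch.cat([p.data.float().view(-1) for p in params]))
optimizer = torch.optim.SGD([fp32_params], lr=0.1)

# ... forward/backward runs in FP16, gradients are reduced into fp32_params.grad,
# then optimizer.step() updates the FP32 master copy ...

# copy the updated FP32 master weights back into the FP16 model
offset = 0
with torch.no_grad():
    for p in params:
        numel = p.data.numel()
        p.copy_(fp32_params[offset:offset + numel].view_as(p).half())
        offset += numel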
implementations/pytorch/fairseq/meters.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import time


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class TimeMeter(object):
    """Computes the average occurrence of some event per second"""

    def __init__(self, init=0):
        self.reset(init)

    def reset(self, init=0):
        self.init = init
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def avg(self):
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        return self.init + (time.time() - self.start)


class StopwatchMeter(object):
    """Computes the sum/avg duration of some event in seconds"""

    def __init__(self):
        self.reset()

    def start(self):
        self.start_time = time.time()

    def stop(self, n=1):
        if self.start_time is not None:
            delta = time.time() - self.start_time
            self.sum += delta
            self.n += n
            self.start_time = None

    def reset(self):
        self.sum = 0
        self.n = 0
        self.start_time = None

    @property
    def avg(self):
        return self.sum / self.n
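A minimal sketch (not part of the commit) of typical use of the three meters above; the values are made up for illustration.

import time

loss_meter = AverageMeter()
loss_meter.update(2.0, n=3)       # batch of 3 samples with mean loss 2.0
loss_meter.update(4.0, n=1)
assert loss_meter.avg == 2.5      # (2.0 * 3 + 4.0 * 1) / 4

sw = StopwatchMeter()             # wall-clock duration of a step
wps = TimeMeter()                 # tokens processed per second
sw.start()
time.sleep(0.01)
wps.update(1000)
sw.stop()
print('{:.0f} tokens/s over {:.3f} s'.format(wps.avg, sw.avg))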
implementations/pytorch/fairseq/models/__init__.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import importlib
import os

from .fairseq_decoder import FairseqDecoder  # noqa: F401
from .fairseq_encoder import FairseqEncoder  # noqa: F401
from .fairseq_incremental_decoder import FairseqIncrementalDecoder  # noqa: F401
from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel  # noqa: F401
from .composite_encoder import CompositeEncoder  # noqa: F401


MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}


def build_model(args, task):
    return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)


def register_model(name):
    """Decorator to register a new model (e.g., LSTM)."""

    def register_model_cls(cls):
        if name in MODEL_REGISTRY:
            raise ValueError('Cannot register duplicate model ({})'.format(name))
        if not issubclass(cls, BaseFairseqModel):
            raise ValueError('Model ({}: {}) must extend BaseFairseqModel'.format(name, cls.__name__))
        MODEL_REGISTRY[name] = cls
        return cls

    return register_model_cls


def register_model_architecture(model_name, arch_name):
    """Decorator to register a new model architecture (e.g., lstm_luong_wmt_en_de)."""

    def register_model_arch_fn(fn):
        if model_name not in MODEL_REGISTRY:
            raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name))
        if arch_name in ARCH_MODEL_REGISTRY:
            raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name))
        if not callable(fn):
            raise ValueError('Model architecture must be callable ({})'.format(arch_name))
        ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
        ARCH_CONFIG_REGISTRY[arch_name] = fn
        return fn

    return register_model_arch_fn


# automatically import any Python files in the models/ directory
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        module = file[:file.find('.py')]
        importlib.import_module('fairseq.models.' + module)
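A minimal sketch (not part of the commit) of how the registries above are meant to be used; 'toy_lstm', 'toy_lstm_big', ToyLSTMModel, and the --hidden-dim argument are hypothetical names for illustration only.

@register_model('toy_lstm')
class ToyLSTMModel(FairseqModel):

    @staticmethod
    def add_args(parser):
        parser.add_argument('--hidden-dim', type=int, metavar='N')

    @classmethod
    def build_model(cls, args, task):
        raise NotImplementedError  # a real model would construct its encoder/decoder here


@register_model_architecture('toy_lstm', 'toy_lstm_big')
def toy_lstm_big(args):
    args.hidden_dim = getattr(args, 'hidden_dim', 1024)

# After registration, build_model(args, task) with args.arch == 'toy_lstm_big'
# dispatches to ToyLSTMModel.build_model via ARCH_MODEL_REGISTRY.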
implementations/pytorch/fairseq/models/__pycache__/__init__.cpython-310.pyc
File added
implementations/pytorch/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc
File added
implementations/pytorch/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc
File added
implementations/pytorch/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc
File added
implementations/pytorch/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc
File added
implementations/pytorch/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc
File added
implementations/pytorch/fairseq/models/__pycache__/fconv.cpython-310.pyc
File added
implementations/pytorch/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc
File added
implementations/pytorch/fairseq/models/__pycache__/lstm.cpython-310.pyc
File added