Commit bd4db8fb
Authored Jun 24, 2018 by Myle Ott
Parent: c6fe9fc5

Misc changes for pytorch-translate

Showing 8 changed files with 29 additions and 21 deletions (+29 −21):
fairseq/data/dictionary.py                    +4  −4
fairseq/data/indexed_dataset.py               +8  −3
fairseq/fp16_trainer.py                       +1  −1
fairseq/optim/lr_scheduler/fixed_schedule.py  +3  −0
fairseq/tasks/language_modeling.py            +1  −1
fairseq/tasks/translation.py                  +1  −1
fairseq/trainer.py                            +10 −10
fairseq/utils.py                              +1  −1
fairseq/data/dictionary.py

@@ -106,7 +106,7 @@ class Dictionary(object):
                 multiple of 8, which is important on some hardware (e.g., Nvidia
                 Tensor Cores).
         """
-        if nwords == -1:
+        if nwords <= 0:
             nwords = len(self)
         new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial)))
@@ -133,7 +133,7 @@ class Dictionary(object):
                 i += 1
                 threshold_nwords += 1

-        assert min(new_count[self.nspecial:]) >= threshold
+        assert len(new_count) == self.nspecial or min(new_count[self.nspecial:]) >= threshold
         assert len(new_symbols) % padding_factor == 0
         assert len(new_symbols) == len(new_indices)
@@ -187,12 +187,12 @@ class Dictionary(object):
             d.count.append(count)
         return d

-    def save(self, f, threshold=3, nwords=-1):
+    def save(self, f):
         """Stores dictionary into a text file"""
         if isinstance(f, str):
             os.makedirs(os.path.dirname(f), exist_ok=True)
             with open(f, 'w', encoding='utf-8') as fd:
-                return self.save(fd, threshold, nwords)
+                return self.save(fd)
         for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
             print('{} {}'.format(symbol, count), file=f)
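For context, a minimal sketch of what the relaxed guard changes; resolve_nwords is a hypothetical stand-in for the first lines of Dictionary.finalize, not part of the commit:

def resolve_nwords(nwords, dict_len):
    # Previously only the sentinel -1 was remapped to the full dictionary
    # size; the new guard treats any non-positive value the same way.
    if nwords <= 0:
        nwords = dict_len
    return nwords

assert resolve_nwords(-1, 50000) == 50000  # old sentinel still works
assert resolve_nwords(0, 50000) == 50000   # newly covered case
assert resolve_nwords(100, 50000) == 100   # explicit sizes unchanged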
fairseq/data/indexed_dataset.py

@@ -52,8 +52,9 @@ def data_file_path(prefix_path):
 class IndexedDataset(torch.utils.data.Dataset):
     """Loader for TorchNet IndexedDataset"""

-    def __init__(self, path):
+    def __init__(self, path, fix_lua_indexing=False):
         super().__init__()
+        self.fix_lua_indexing = fix_lua_indexing
         with open(index_file_path(path), 'rb') as f:
             magic = f.read(8)
             assert magic == b'TNTIDX\x00\x00'
@@ -83,7 +84,10 @@ class IndexedDataset(torch.utils.data.Dataset):
         a = np.empty(tensor_size, dtype=self.dtype)
         self.data_file.seek(self.data_offsets[i] * self.element_size)
         self.data_file.readinto(a)
-        return torch.from_numpy(a).long() - 1  # subtract 1 for 0-based indexing
+        item = torch.from_numpy(a).long()
+        if self.fix_lua_indexing:
+            item -= 1  # subtract 1 for 0-based indexing
+        return item

     def __len__(self):
         return self.size
@@ -104,6 +108,7 @@ class IndexedInMemoryDataset(IndexedDataset):
         self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
         self.data_file.readinto(self.buffer)
         self.data_file.close()
-        self.buffer -= 1  # subtract 1 for 0-based indexing
+        if self.fix_lua_indexing:
+            self.buffer -= 1  # subtract 1 for 0-based indexing

     def __del__(self):
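The flag makes the 1-based-to-0-based shift opt-in: Lua/TorchNet-binarized data stores token ids starting at 1, while, judging from the commit title, pytorch-translate needs to load data whose ids are already 0-based. A hedged usage sketch (the data path is hypothetical):

from fairseq.data.indexed_dataset import IndexedInMemoryDataset

# Legacy Lua/TorchNet-binarized data: ask the loader to shift ids down
# to 0-based dictionary indices, matching the old hard-coded behavior.
legacy_ds = IndexedInMemoryDataset('data-bin/train.src-tgt.src',
                                   fix_lua_indexing=True)

# Already 0-based data can now be loaded unshifted (the new default).
plain_ds = IndexedInMemoryDataset('data-bin/train.src-tgt.src')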
fairseq/fp16_trainer.py

@@ -73,7 +73,7 @@ class FP16Trainer(Trainer):
         self.fp32_params.grad = self.fp32_params.data.new(total_param_size)

         # create optimizer using the copied FP32 params
-        self.optimizer = optim.build_optimizer(self.args, [self.fp32_params])
+        self._optimizer = optim.build_optimizer(self.args, [self.fp32_params])

         self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

     def save_checkpoint(self, filename, extra_state):
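This rename is the FP16 counterpart of the lazy optimizer property added to Trainer below: once optimizer becomes a read-only property, FP16Trainer must assign to the backing field self._optimizer, and the following self.optimizer read resolves through the property.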
fairseq/optim/lr_scheduler/fixed_schedule.py

@@ -15,6 +15,9 @@ class FixedSchedule(FairseqLRScheduler):
     def __init__(self, args, optimizer):
         super().__init__(args, optimizer)

+        # set defaults
+        args.warmup_updates = getattr(args, 'warmup_updates', 0)
+
         self.lr = args.lr[0]
         if args.warmup_updates > 0:
             self.warmup_factor = 1. / args.warmup_updates
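The getattr line is a backward-compatibility idiom: argument namespaces restored from older checkpoints, or built by external callers such as pytorch-translate, may predate --warmup-updates. A small sketch with a hypothetical namespace (apply_defaults is illustrative, not part of the commit):

from argparse import Namespace

def apply_defaults(args):
    # Mirrors the diff: fill in warmup_updates only when it is missing,
    # so explicit values are left untouched.
    args.warmup_updates = getattr(args, 'warmup_updates', 0)
    return args

old_args = apply_defaults(Namespace(lr=[0.25]))  # predates the option
assert old_args.warmup_updates == 0              # default: no warmup
new_args = apply_defaults(Namespace(lr=[0.25], warmup_updates=4000))
assert new_args.warmup_updates == 4000           # explicit value kept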
fairseq/tasks/language_modeling.py

@@ -50,7 +50,7 @@ class LanguageModelingTask(FairseqTask):
             ds = IndexedRawTextDataset(path, self.dictionary)
             tokens = ds.tokens_list
         elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
-            ds = IndexedInMemoryDataset(path)
+            ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
             tokens = ds.buffer
         else:
             raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
fairseq/tasks/translation.py

@@ -89,7 +89,7 @@ class TranslationTask(FairseqTask):
             if self.args.raw_text:
                 return IndexedRawTextDataset(path, dictionary)
             elif IndexedInMemoryDataset.exists(path):
-                return IndexedInMemoryDataset(path)
+                return IndexedInMemoryDataset(path, fix_lua_indexing=True)
             return None

         src_dataset = indexed_dataset(prefix + src, self.src_dict)
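Both built-in tasks pass fix_lua_indexing=True, so fairseq's own binarized datasets keep their previous behavior; only callers that construct the datasets directly (such as pytorch-translate, going by the commit title) see the new default of no index shifting.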
fairseq/trainer.py

@@ -40,8 +40,6 @@ class Trainer(object):
             self.model = model.cuda()
             self.criterion = criterion.cuda()

-        self.optimizer = None
-
         # initialize meters
         self.meters = OrderedDict()
         self.meters['train_loss'] = AverageMeter()
@@ -61,10 +59,17 @@ class Trainer(object):
         self._flat_grads = None
         self._num_updates = 0
         self._optim_history = None
+        self._optimizer = None
+
+    @property
+    def optimizer(self):
+        if self._optimizer is None:
+            self._build_optimizer()
+        return self._optimizer

     def _build_optimizer(self):
-        self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
-        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
+        self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
+        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)

     def save_checkpoint(self, filename, extra_state):
         """Save all training state in a checkpoint file."""
@@ -93,7 +98,7 @@ class Trainer(object):
         self._num_updates = last_optim['num_updates']

-        if 'train_meters' in extra_state:
+        if extra_state is not None and 'train_meters' in extra_state:
             self.meters = extra_state['train_meters']
             del extra_state['train_meters']
@@ -101,11 +106,6 @@ class Trainer(object):
     def train_step(self, sample, update_params=True):
         """Do forward, backward and parameter update."""
-        if self.optimizer is None:
-            # initialize optimizer and LR scheduler if hasn't been loaded from the checkpoint
-            self._build_optimizer()
-
         # Set seed based on args.seed and the update number so that we get
         # reproducible results when resuming from checkpoints
         seed = self.args.seed + self.get_num_updates()
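Taken together, these hunks replace eager optimizer construction (and the ad-hoc None check in train_step) with a lazily-built _optimizer behind a read-only property, so load_checkpoint can run before anything forces construction. A stripped-down sketch of the pattern; the object() placeholder stands in for optim.build_optimizer:

class LazyOptimizerExample:
    """Minimal sketch of the property pattern used in the diff."""

    def __init__(self):
        self._optimizer = None  # built on first access

    @property
    def optimizer(self):
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    def _build_optimizer(self):
        # Stand-in for optim.build_optimizer(self.args, self.model.parameters())
        self._optimizer = object()

t = LazyOptimizerExample()
assert t._optimizer is None   # nothing built yet
_ = t.optimizer               # first read triggers _build_optimizer()
assert t._optimizer is not None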
fairseq/utils.py

@@ -126,7 +126,7 @@ def _upgrade_state_dict(state):
     if 'train_iterator' not in state['extra_state']:
         state['extra_state']['train_iterator'] = {
             'epoch': state['extra_state']['epoch'],
-            'iterations_in_epoch': 0,
+            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
         }
     return state
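A hedged sketch of the effect: older checkpoints recorded their resume position as batch_offset, and the upgrade now carries that offset over instead of restarting the epoch at iteration 0. The sample values below are made up:

# Hypothetical legacy extra_state as found in an old checkpoint.
legacy = {'epoch': 3, 'batch_offset': 1200}

# Mirrors the upgraded branch of _upgrade_state_dict.
train_iterator = {
    'epoch': legacy['epoch'],
    'iterations_in_epoch': legacy.get('batch_offset', 0),
}
assert train_iterator == {'epoch': 3, 'iterations_in_epoch': 1200}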