Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
94dae690
Unverified
Commit
94dae690
authored
Dec 06, 2017
by
Myle Ott
Committed by
GitHub
Dec 06, 2017
Browse files
Merge pull request #77 from facebookresearch/oss-merge-internal
parents
d74f200a
0a836276
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
193 additions
and
108 deletions
+193
-108
fairseq/models/fairseq_decoder.py
fairseq/models/fairseq_decoder.py
+3
-0
fairseq/models/fairseq_encoder.py
fairseq/models/fairseq_encoder.py
+3
-0
fairseq/models/fairseq_model.py
fairseq/models/fairseq_model.py
+5
-0
fairseq/models/fconv.py
fairseq/models/fconv.py
+13
-2
fairseq/modules/conv_tbc.py
fairseq/modules/conv_tbc.py
+1
-1
fairseq/multiprocessing_trainer.py
fairseq/multiprocessing_trainer.py
+145
-81
fairseq/options.py
fairseq/options.py
+2
-2
fairseq/utils.py
fairseq/utils.py
+10
-11
train.py
train.py
+11
-11
No files found.
fairseq/models/fairseq_decoder.py
View file @
94dae690
...
...
@@ -18,3 +18,6 @@ class FairseqDecoder(nn.Module):
def
max_positions
(
self
):
"""Maximum input length supported by the decoder."""
raise
NotImplementedError
def
upgrade_state_dict
(
self
,
state_dict
):
return
state_dict
fairseq/models/fairseq_encoder.py
View file @
94dae690
...
...
@@ -18,3 +18,6 @@ class FairseqEncoder(nn.Module):
def
max_positions
(
self
):
"""Maximum input length supported by the encoder."""
raise
NotImplementedError
def
upgrade_state_dict
(
self
,
state_dict
):
return
state_dict
fairseq/models/fairseq_model.py
View file @
94dae690
...
...
@@ -43,6 +43,11 @@ class FairseqModel(nn.Module):
"""Maximum output length supported by the decoder."""
return
self
.
decoder
.
max_positions
()
def
upgrade_state_dict
(
self
,
state_dict
):
state_dict
=
self
.
encoder
.
upgrade_state_dict
(
state_dict
)
state_dict
=
self
.
decoder
.
upgrade_state_dict
(
state_dict
)
return
state_dict
def
make_generation_fast_
(
self
,
**
kwargs
):
"""Optimize model for faster generation."""
if
self
.
_is_generation_fast
:
...
...
fairseq/models/fconv.py
View file @
94dae690
...
...
@@ -58,7 +58,7 @@ class FConvEncoder(FairseqEncoder):
self
.
projections
=
nn
.
ModuleList
()
self
.
convolutions
=
nn
.
ModuleList
()
for
(
out_channels
,
kernel_size
)
in
convolutions
:
pad
=
(
kernel_size
-
1
)
/
/
2
pad
=
(
kernel_size
-
1
)
/
2
self
.
projections
.
append
(
Linear
(
in_channels
,
out_channels
)
if
in_channels
!=
out_channels
else
None
)
self
.
convolutions
.
append
(
...
...
@@ -154,6 +154,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
max_positions
=
1024
,
convolutions
=
((
512
,
3
),)
*
20
,
attention
=
True
,
dropout
=
0.1
):
super
().
__init__
()
self
.
register_buffer
(
'version'
,
torch
.
Tensor
([
2
]))
self
.
dictionary
=
dictionary
self
.
dropout
=
dropout
...
...
@@ -265,6 +266,16 @@ class FConvDecoder(FairseqIncrementalDecoder):
"""Maximum output length supported by the decoder."""
return
self
.
embed_positions
.
num_embeddings
-
self
.
dictionary
.
pad
()
-
1
def
upgrade_state_dict
(
self
,
state_dict
):
if
state_dict
.
get
(
'decoder.version'
,
torch
.
Tensor
([
1
]))[
0
]
<
2
:
# old models use incorrect weight norm dimension
for
i
,
conv
in
enumerate
(
self
.
convolutions
):
# reconfigure weight norm
nn
.
utils
.
remove_weight_norm
(
conv
)
self
.
convolutions
[
i
]
=
nn
.
utils
.
weight_norm
(
conv
,
dim
=
0
)
state_dict
[
'decoder.version'
]
=
torch
.
Tensor
([
1
])
return
state_dict
def
_split_encoder_out
(
self
,
encoder_out
):
"""Split and transpose encoder outputs.
...
...
@@ -307,7 +318,7 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs
std
=
math
.
sqrt
((
4
*
(
1.0
-
dropout
))
/
(
m
.
kernel_size
[
0
]
*
in_channels
))
m
.
weight
.
data
.
normal_
(
mean
=
0
,
std
=
std
)
m
.
bias
.
data
.
zero_
()
return
nn
.
utils
.
weight_norm
(
m
)
return
nn
.
utils
.
weight_norm
(
m
,
dim
=
2
)
def
ConvTBC
(
in_channels
,
out_channels
,
kernel_size
,
dropout
=
0
,
**
kwargs
):
...
...
fairseq/modules/conv_tbc.py
View file @
94dae690
...
...
@@ -59,7 +59,7 @@ class ConvTBCFunction(Function):
kernel_size
=
weight_size
[
0
]
output
=
input
.
new
(
input_size
[
0
]
-
kernel_size
+
1
+
pad
*
2
,
input_size
[
0
]
-
kernel_size
+
1
+
int
(
pad
*
2
)
,
input_size
[
1
],
weight_size
[
2
])
...
...
fairseq/multiprocessing_trainer.py
View file @
94dae690
...
...
@@ -11,10 +11,11 @@ Train a network on multiple GPUs using multiprocessing.
"""
from
itertools
import
cycle
,
islice
import
math
import
torch
from
torch.optim.lr_scheduler
import
LambdaLR
,
ReduceLROnPlateau
from
fairseq
import
nccl
,
utils
from
fairseq
import
meters
,
nccl
,
utils
from
fairseq.multiprocessing_event_loop
import
MultiprocessingEventLoop
,
Future
from
fairseq.nag
import
NAG
...
...
@@ -67,39 +68,61 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self
.
model
=
model
.
cuda
()
self
.
criterion
=
criterion
.
cuda
()
# initialize optimizer
# initialize optimizer and LR scheduler
self
.
args
.
lr
=
list
(
map
(
float
,
self
.
args
.
lr
.
split
(
','
)))
self
.
optimizer
=
self
.
_build_optimizer
()
self
.
loss
=
None
# initialize LR scheduler
self
.
lr_scheduler
=
self
.
_build_lr_scheduler
()
self
.
loss
=
None
self
.
_max_bsz_seen
=
0
def
_build_optimizer
(
self
):
# When resuming training from a checkpoint, we load the old optimizer
# state that includes things like learning rate, momentum factor, etc.
# We use this dictionary to override values stored in the checkpoint,
# e.g., we might prefer the values specified on the command line.
self
.
_override_optim_state
=
{}
if
self
.
args
.
optimizer
==
'adagrad'
:
return
torch
.
optim
.
Adagrad
(
self
.
model
.
parameters
(),
lr
=
self
.
args
.
lr
,
weight_decay
=
self
.
args
.
weight_decay
)
self
.
_override_optim_state
=
{
'lr'
:
self
.
args
.
lr
[
0
],
'weight_decay'
:
self
.
args
.
weight_decay
,
}
return
torch
.
optim
.
Adagrad
(
self
.
model
.
parameters
(),
**
self
.
_override_optim_state
)
elif
self
.
args
.
optimizer
==
'adam'
:
return
torch
.
optim
.
Adam
(
self
.
model
.
parameters
(),
lr
=
self
.
args
.
lr
,
betas
=
eval
(
self
.
args
.
adam_betas
),
weight_decay
=
self
.
args
.
weight_decay
)
self
.
_override_optim_state
=
{
'lr'
:
self
.
args
.
lr
[
0
],
'betas'
:
eval
(
self
.
args
.
adam_betas
),
'weight_decay'
:
self
.
args
.
weight_decay
,
}
return
torch
.
optim
.
Adam
(
self
.
model
.
parameters
(),
**
self
.
_override_optim_state
)
elif
self
.
args
.
optimizer
==
'nag'
:
return
NAG
(
self
.
model
.
parameters
(),
lr
=
self
.
args
.
lr
,
momentum
=
self
.
args
.
momentum
,
weight_decay
=
self
.
args
.
weight_decay
)
self
.
_override_optim_state
=
{
'lr'
:
self
.
args
.
lr
[
0
],
'momentum'
:
self
.
args
.
momentum
,
'weight_decay'
:
self
.
args
.
weight_decay
,
}
return
NAG
(
self
.
model
.
parameters
(),
**
self
.
_override_optim_state
)
elif
self
.
args
.
optimizer
==
'sgd'
:
return
torch
.
optim
.
SGD
(
self
.
model
.
parameters
(),
lr
=
self
.
args
.
lr
,
momentum
=
self
.
args
.
momentum
,
weight_decay
=
self
.
args
.
weight_decay
)
self
.
_override_optim_state
=
{
'lr'
:
self
.
args
.
lr
[
0
],
'momentum'
:
self
.
args
.
momentum
,
'weight_decay'
:
self
.
args
.
weight_decay
,
}
return
torch
.
optim
.
SGD
(
self
.
model
.
parameters
(),
**
self
.
_override_optim_state
)
else
:
raise
ValueError
(
'Unknown optimizer: {}'
.
format
(
self
.
args
.
optimizer
))
def
_build_lr_scheduler
(
self
):
if
self
.
args
.
force_anneal
>
0
:
if
len
(
self
.
args
.
lr
)
>
1
or
self
.
args
.
force_anneal
>
0
:
lrs
=
self
.
args
.
lr
def
anneal
(
e
):
if
e
<
self
.
args
.
force_anneal
:
return
1
# use fixed LR schedule
next_lr
=
lrs
[
min
(
e
,
len
(
lrs
)
-
1
)]
else
:
return
self
.
args
.
lrshrink
**
(
e
+
1
-
self
.
args
.
force_anneal
)
next_lr
=
lrs
[
-
1
]
*
self
.
args
.
lrshrink
**
(
e
+
1
-
self
.
args
.
force_anneal
)
return
next_lr
/
lrs
[
0
]
# correct for scaling from LambdaLR
lr_scheduler
=
LambdaLR
(
self
.
optimizer
,
anneal
)
lr_scheduler
.
best
=
None
else
:
...
...
@@ -134,9 +157,24 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
return
extra_state
def
_async_load_checkpoint
(
self
,
rank
,
device_id
,
filename
):
extra_state
,
self
.
_optim_history
=
utils
.
load_state
(
filename
,
self
.
model
,
self
.
criterion
,
self
.
optimizer
,
self
.
lr_scheduler
,
cuda_device
=
device_id
)
extra_state
,
self
.
_optim_history
,
last_optim_state
=
utils
.
load_model_state
(
filename
,
self
.
model
,
cuda_device
=
device_id
)
if
last_optim_state
is
not
None
:
# rebuild optimizer after loading model, since params may have changed
self
.
optimizer
=
self
.
_build_optimizer
()
self
.
lr_scheduler
=
self
.
_build_lr_scheduler
()
# only load optimizer and lr_scheduler if they match the checkpoint
last_optim
=
self
.
_optim_history
[
-
1
]
if
last_optim
[
'criterion_name'
]
==
self
.
criterion
.
__class__
.
__name__
:
self
.
optimizer
.
load_state_dict
(
last_optim_state
)
self
.
lr_scheduler
.
best
=
last_optim
[
'best_loss'
]
# override learning rate, momentum, etc. with latest values
for
group
in
self
.
optimizer
.
param_groups
:
group
.
update
(
self
.
_override_optim_state
)
return
extra_state
def
set_seed
(
self
,
seed
):
...
...
@@ -161,14 +199,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self
.
_scatter_samples
(
samples
,
replace_empty_samples
=
replace_empty_samples
)
# forward pass
sample_sizes
,
logging_outputs
=
Future
.
gen_tuple_list
([
sample_sizes
,
logging_outputs
,
ooms_fwd
=
Future
.
gen_tuple_list
([
self
.
call_async
(
rank
,
'_async_forward'
)
for
rank
in
range
(
self
.
num_replicas
)
])
# backward pass, all-reduce gradients and take an optimization step
grad_denom
=
self
.
criterion
.
__class__
.
grad_denom
(
sample_sizes
)
grad_norms
=
Future
.
gen_list
([
grad_norms
,
ooms_bwd
=
Future
.
gen_
tuple_
list
([
self
.
call_async
(
rank
,
'_async_backward_and_opt'
,
grad_denom
=
grad_denom
)
for
rank
in
range
(
self
.
num_replicas
)
])
...
...
@@ -176,6 +214,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# aggregate logging output
logging_output
=
self
.
criterion
.
__class__
.
aggregate_logging_outputs
(
logging_outputs
)
logging_output
[
'gnorm'
]
=
grad_norms
[
0
]
# log the gradient norm
logging_output
[
'oom'
]
=
sum
(
ooms_fwd
)
+
sum
(
ooms_bwd
)
return
logging_output
...
...
@@ -186,34 +225,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self
.
model
.
train
()
self
.
optimizer
.
zero_grad
()
if
self
.
_sample
is
None
:
return
0
,
{}
# calculate loss and sample size
self
.
loss
,
sample_size
,
logging_output
=
self
.
criterion
(
self
.
model
,
self
.
_sample
)
sample_size
,
logging_output
,
oom
=
0
,
{},
False
if
self
.
_sample
is
not
None
:
try
:
# calculate loss and sample size
self
.
loss
,
sample_size
,
logging_output
=
self
.
criterion
(
self
.
model
,
self
.
_sample
)
except
RuntimeError
as
e
:
if
not
eval
and
'out of memory'
in
str
(
e
):
print
(
'| WARNING: ran out of memory on GPU #{}, skipping batch'
.
format
(
device_id
))
oom
=
True
self
.
loss
=
None
if
hasattr
(
torch
.
cuda
,
'empty_cache'
):
torch
.
cuda
.
empty_cache
()
else
:
raise
e
return
sample_size
,
logging_output
return
sample_size
,
logging_output
,
oom
def
_async_backward_and_opt
(
self
,
rank
,
device_id
,
grad_denom
):
oom
=
False
if
self
.
loss
is
not
None
:
# backward pass
self
.
loss
.
backward
()
# get model parameters as a flattened (contiguous) tensor
flat_grads
=
self
.
_flat_model_grads
()
# all-reduce grads
nccl
.
all_reduce
(
flat_grads
)
try
:
# backward pass
self
.
loss
.
backward
()
except
RuntimeError
as
e
:
if
'out of memory'
in
str
(
e
):
print
(
'| WARNING: ran out of memory on GPU #{}, skipping batch'
.
format
(
device_id
))
oom
=
True
if
hasattr
(
torch
.
cuda
,
'empty_cache'
):
torch
.
cuda
.
empty_cache
()
self
.
optimizer
.
zero_grad
()
else
:
raise
e
# normalize grads
if
grad_denom
!=
0
:
flat_grads
.
div_
(
grad_denom
)
# all-reduce grads and rescale by grad_denom
self
.
_all_reduce_and_rescale_grads
(
grad_denom
)
# clip grads
grad_norm
=
self
.
_clip_grads_
(
flat_grads
,
self
.
args
.
clip_norm
)
# copy reduced grads back
self
.
_set_model_grads_
(
flat_grads
)
grad_norm
=
torch
.
nn
.
utils
.
clip_grad_norm
(
self
.
model
.
parameters
(),
self
.
args
.
clip_norm
)
# take an optimization step
self
.
optimizer
.
step
()
...
...
@@ -221,41 +270,49 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# reset loss
self
.
loss
=
None
return
grad_norm
def
_model_grads
(
self
):
return
[
p
.
grad
for
p
in
self
.
model
.
parameters
()
if
p
.
requires_grad
]
def
_flat_model_grads
(
self
):
grads
=
self
.
_model_grads
()
if
not
hasattr
(
self
,
'_flat_grads'
):
num_params
=
sum
(
g
.
data
.
numel
()
for
g
in
grads
)
self
.
_flat_grads
=
grads
[
0
].
data
.
new
(
num_params
)
offset
=
0
for
grad
in
grads
:
grad
=
grad
.
data
.
view
(
-
1
)
numel
=
grad
.
numel
()
self
.
_flat_grads
[
offset
:
offset
+
numel
].
copy_
(
grad
)
offset
+=
numel
return
self
.
_flat_grads
def
_set_model_grads_
(
self
,
flat_grads
):
grads
=
self
.
_model_grads
()
offset
=
0
for
grad
in
grads
:
grad
=
grad
.
data
.
view
(
-
1
)
numel
=
grad
.
numel
()
grad
.
copy_
(
flat_grads
[
offset
:
offset
+
numel
])
offset
+=
numel
assert
offset
==
flat_grads
.
numel
()
def
_clip_grads_
(
self
,
flat_grads
,
clipv
):
"""nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
norm
=
flat_grads
.
norm
()
if
clipv
>
0
and
norm
>
clipv
:
coef
=
max
(
norm
,
1e-6
)
/
clipv
flat_grads
.
div_
(
coef
)
return
norm
return
grad_norm
,
oom
def
_all_reduce_and_rescale_grads
(
self
,
grad_denom
,
buffer_size
=
10485760
):
"""All-reduce and rescale gradients in chunks of the specified size."""
grads
=
[
p
.
grad
.
data
for
p
in
self
.
model
.
parameters
()
if
p
.
requires_grad
]
buffer_t
=
grads
[
0
].
new
(
math
.
ceil
(
buffer_size
/
grads
[
0
].
element_size
())).
zero_
()
buffer
=
[]
def
all_reduce_buffer
():
# copy grads into buffer_t
offset
=
0
for
g
in
buffer
:
numel
=
g
.
numel
()
buffer_t
[
offset
:
offset
+
numel
].
copy_
(
g
.
view
(
-
1
))
offset
+=
numel
# all-reduce and rescale
nccl
.
all_reduce
(
buffer_t
[:
offset
])
buffer_t
.
div_
(
grad_denom
)
# copy all-reduced buffer back into grads
offset
=
0
for
g
in
buffer
:
numel
=
g
.
numel
()
g
.
view
(
-
1
).
copy_
(
buffer_t
[
offset
:
offset
+
numel
])
offset
+=
numel
filled
=
0
for
g
in
grads
:
sz
=
g
.
numel
()
*
g
.
element_size
()
if
sz
>
buffer_size
:
# grad is bigger than buffer, all-reduce and rescale directly
nccl
.
all_reduce
(
g
)
g
.
div_
(
grad_denom
)
elif
filled
+
sz
>
buffer_size
:
# buffer is full, all-reduce and replace buffer with grad
all_reduce_buffer
()
buffer
=
[
g
]
filled
=
sz
else
:
# add grad to buffer
buffer
.
append
(
g
)
filled
+=
sz
if
len
(
buffer
)
>
0
:
all_reduce_buffer
()
def
valid_step
(
self
,
samples
):
"""Do forward pass in parallel."""
...
...
@@ -263,10 +320,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self
.
_scatter_samples
(
samples
,
volatile
=
True
)
# forward pass
_sample_sizes
,
logging_outputs
=
Future
.
gen_tuple_list
([
_sample_sizes
,
logging_outputs
,
ooms_fwd
=
Future
.
gen_tuple_list
([
self
.
call_async
(
rank
,
'_async_forward'
,
eval
=
True
)
for
rank
in
range
(
self
.
num_replicas
)
])
assert
sum
(
ooms_fwd
)
==
0
# aggregate logging output
logging_output
=
self
.
criterion
.
__class__
.
aggregate_logging_outputs
(
logging_outputs
)
...
...
@@ -314,4 +372,10 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
if
sample
is
None
:
self
.
_sample
=
None
else
:
if
hasattr
(
torch
.
cuda
,
'empty_cache'
):
# clear the caching allocator if this is the largest sample we've seen
if
sample
[
'target'
].
size
(
0
)
>
self
.
_max_bsz_seen
:
self
.
_max_bsz_seen
=
sample
[
'target'
].
size
(
0
)
torch
.
cuda
.
empty_cache
()
self
.
_sample
=
utils
.
prepare_sample
(
sample
,
volatile
=
volatile
,
cuda_device
=
device_id
)
fairseq/options.py
View file @
94dae690
...
...
@@ -49,8 +49,8 @@ def add_optimization_args(parser):
group
.
add_argument
(
'--optimizer'
,
default
=
'nag'
,
metavar
=
'OPT'
,
choices
=
MultiprocessingTrainer
.
OPTIMIZERS
,
help
=
'optimizer ({})'
.
format
(
', '
.
join
(
MultiprocessingTrainer
.
OPTIMIZERS
)))
group
.
add_argument
(
'--lr'
,
'--learning-rate'
,
default
=
0.25
,
type
=
float
,
metavar
=
'LR'
,
help
=
'
initial
learning rate'
)
group
.
add_argument
(
'--lr'
,
'--learning-rate'
,
default
=
'
0.25
'
,
metavar
=
'LR
1,LR2,...,LRn
'
,
help
=
'learning rate
for the first n epochs with all epochs >n using LRn
'
)
group
.
add_argument
(
'--min-lr'
,
metavar
=
'LR'
,
default
=
1e-5
,
type
=
float
,
help
=
'minimum learning rate'
)
group
.
add_argument
(
'--force-anneal'
,
'--fa'
,
default
=
0
,
type
=
int
,
metavar
=
'N'
,
...
...
fairseq/utils.py
View file @
94dae690
...
...
@@ -83,9 +83,9 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_
torch_persistent_save
(
state_dict
,
filename
)
def
load_state
(
filename
,
model
,
criterion
,
optimizer
,
lr_scheduler
,
cuda_device
=
None
):
def
load_
model_
state
(
filename
,
model
,
cuda_device
=
None
):
if
not
os
.
path
.
exists
(
filename
):
return
None
,
[]
return
None
,
[]
,
None
if
cuda_device
is
None
:
state
=
torch
.
load
(
filename
)
else
:
...
...
@@ -94,18 +94,16 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=
map_location
=
lambda
s
,
l
:
default_restore_location
(
s
,
'cuda:{}'
.
format
(
cuda_device
))
)
state
=
_upgrade_state_dict
(
state
)
state
[
'model'
]
=
model
.
upgrade_state_dict
(
state
[
'model'
])
# load model parameters
model
.
load_state_dict
(
state
[
'model'
])
# only load optimizer and lr_scheduler if they match with the checkpoint
optim_history
=
state
[
'optimizer_history'
]
last_optim
=
optim_history
[
-
1
]
if
last_optim
[
'criterion_name'
]
==
criterion
.
__class__
.
__name__
:
optimizer
.
load_state_dict
(
state
[
'last_optimizer_state'
])
lr_scheduler
.
best
=
last_optim
[
'best_loss'
]
try
:
model
.
load_state_dict
(
state
[
'model'
])
except
:
raise
Exception
(
'Cannot load model parameters from checkpoint, '
'please ensure that the architectures match'
)
return
state
[
'extra_state'
],
optim_history
return
state
[
'extra_state'
],
state
[
'
optim
izer
_history
'
],
state
[
'last_optimizer_state'
]
def
_upgrade_state_dict
(
state
):
...
...
@@ -164,6 +162,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
ensemble
=
[]
for
state
in
states
:
model
=
build_model
(
args
,
src_dict
,
dst_dict
)
state
[
'model'
]
=
model
.
upgrade_state_dict
(
state
[
'model'
])
model
.
load_state_dict
(
state
[
'model'
])
ensemble
.
append
(
model
)
return
ensemble
,
args
...
...
train.py
View file @
94dae690
...
...
@@ -53,18 +53,18 @@ def main():
# record inferred languages in args, so that it's saved in checkpoints
args
.
source_lang
,
args
.
target_lang
=
dataset
.
src
,
dataset
.
dst
if
not
torch
.
cuda
.
is_available
():
raise
NotImplementedError
(
'Training on CPU is not supported'
)
args
.
num_gpus
=
torch
.
cuda
.
device_count
()
print
(
args
)
print
(
'| [{}] dictionary: {} types'
.
format
(
dataset
.
src
,
len
(
dataset
.
src_dict
)))
print
(
'| [{}] dictionary: {} types'
.
format
(
dataset
.
dst
,
len
(
dataset
.
dst_dict
)))
for
split
in
splits
:
print
(
'| {} {} {} examples'
.
format
(
args
.
data
,
split
,
len
(
dataset
.
splits
[
split
])))
if
not
torch
.
cuda
.
is_available
():
raise
NotImplementedError
(
'Training on CPU is not supported'
)
num_gpus
=
torch
.
cuda
.
device_count
()
print
(
'| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'
.
format
(
num_gpus
,
args
.
max_tokens
,
args
.
max_sentences
))
args
.
num_gpus
,
args
.
max_tokens
,
args
.
max_sentences
))
# Build model and criterion
model
=
utils
.
build_model
(
args
,
dataset
.
src_dict
,
dataset
.
dst_dict
)
...
...
@@ -102,11 +102,11 @@ def main():
train_meter
.
start
()
while
lr
>
args
.
min_lr
and
epoch
<=
max_epoch
:
# train for one epoch
train
(
args
,
epoch
,
batch_offset
,
trainer
,
dataset
,
max_positions_train
,
num_gpus
)
train
(
args
,
epoch
,
batch_offset
,
trainer
,
dataset
,
max_positions_train
)
# evaluate on validate set
for
k
,
subset
in
enumerate
(
args
.
valid_subset
.
split
(
','
)):
val_loss
=
validate
(
args
,
epoch
,
trainer
,
dataset
,
max_positions_valid
,
subset
,
num_gpus
)
val_loss
=
validate
(
args
,
epoch
,
trainer
,
dataset
,
max_positions_valid
,
subset
)
if
k
==
0
:
if
not
args
.
no_save
:
# save checkpoint
...
...
@@ -130,7 +130,7 @@ def get_perplexity(loss):
return
float
(
'inf'
)
def
train
(
args
,
epoch
,
batch_offset
,
trainer
,
dataset
,
max_positions
,
num_gpus
):
def
train
(
args
,
epoch
,
batch_offset
,
trainer
,
dataset
,
max_positions
):
"""Train the model for one epoch."""
seed
=
args
.
seed
+
epoch
...
...
@@ -152,7 +152,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
lr
=
trainer
.
get_lr
()
with
utils
.
build_progress_bar
(
args
,
itr
,
epoch
)
as
t
:
for
i
,
sample
in
data
.
skip_group_enumerator
(
t
,
num_gpus
,
batch_offset
):
for
i
,
sample
in
data
.
skip_group_enumerator
(
t
,
args
.
num_gpus
,
batch_offset
):
loss_dict
=
trainer
.
train_step
(
sample
)
loss
=
loss_dict
[
'loss'
]
del
loss_dict
[
'loss'
]
# don't include in extra_meters or extra_postfix
...
...
@@ -222,7 +222,7 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
trainer
.
save_checkpoint
(
last_filename
,
extra_state
)
def
validate
(
args
,
epoch
,
trainer
,
dataset
,
max_positions
,
subset
,
ngpus
):
def
validate
(
args
,
epoch
,
trainer
,
dataset
,
max_positions
,
subset
):
"""Evaluate the model on the validation set and return the average loss."""
itr
=
dataset
.
eval_dataloader
(
...
...
@@ -236,7 +236,7 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
prefix
=
'valid on
\'
{}
\'
subset'
.
format
(
subset
)
with
utils
.
build_progress_bar
(
args
,
itr
,
epoch
,
prefix
)
as
t
:
for
_
,
sample
in
data
.
skip_group_enumerator
(
t
,
n
gpus
):
for
_
,
sample
in
data
.
skip_group_enumerator
(
t
,
args
.
num_
gpus
):
loss_dict
=
trainer
.
valid_step
(
sample
)
loss
=
loss_dict
[
'loss'
]
del
loss_dict
[
'loss'
]
# don't include in extra_meters or extra_postfix
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment