OpenDAS / Fairseq · Commits

Commit d9f46c54, authored Jan 26, 2018 by Sergey Edunov

    Merge branch 'master' of github.com:facebookresearch/fairseq-py into prepare_wmt

Parents: 4185d3ed, ee36a6f3

Changes: 27 files in the commit; this page shows 20 changed files with 351 additions and 129 deletions (+351, -129).
Changed files on this page:

* README.md (+2, -2)
* fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.cpp (+1, -1)
* fairseq/criterions/cross_entropy.py (+13, -6)
* fairseq/criterions/fairseq_criterion.py (+1, -1)
* fairseq/criterions/label_smoothed_cross_entropy.py (+17, -7)
* fairseq/data.py (+29, -9)
* fairseq/dictionary.py (+2, -0)
* fairseq/models/fairseq_decoder.py (+12, -1)
* fairseq/models/fairseq_encoder.py (+2, -1)
* fairseq/models/fairseq_incremental_decoder.py (+10, -7)
* fairseq/models/fairseq_model.py (+9, -5)
* fairseq/models/fconv.py (+43, -56)
* fairseq/models/lstm.py (+3, -5)
* fairseq/modules/__init__.py (+2, -0)
* fairseq/modules/conv_tbc.py (+6, -4)
* fairseq/modules/learned_positional_embedding.py (+57, -0)
* fairseq/modules/linearized_convolution.py (+6, -2)
* fairseq/multiprocessing_trainer.py (+26, -19)
* fairseq/optim/adam.py (+103, -0)
* fairseq/optim/nag.py (+7, -3)
README.md

@@ -24,8 +24,8 @@ If you use the code in your paper, then please cite it as:
 * Python version 3.6
 * A [PyTorch installation](http://pytorch.org/)
 
-Currently fairseq-py requires installing PyTorch from source.
-Please follow the instructions here: https://github.com/pytorch/pytorch#from-source.
+Currently fairseq-py requires PyTorch version >= 0.3.0.
+Please follow the instructions here: https://github.com/pytorch/pytorch#installation.
 
 If you use Docker make sure to increase the shared memory size either with
 `--ipc=host` or `--shm-size` as command line options to `nvidia-docker run`.
fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.cpp

@@ -126,5 +126,5 @@ void TemporalConvolutionTBC_backward(
   }
 
   auto tmp = dOutput.sum(0, false);
-  dBias.assign_(tmp.sum(0));
+  dBias.copy_(tmp.sum(0));
 }
fairseq/criterions/cross_entropy.py

@@ -17,7 +17,7 @@ class CrossEntropyCriterion(FairseqCriterion):
     def __init__(self, args, dst_dict):
         super().__init__(args, dst_dict)
 
-    def forward(self, model, sample):
+    def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
 
         Returns a tuple with three elements:
@@ -26,12 +26,14 @@ class CrossEntropyCriterion(FairseqCriterion):
         3) logging outputs to display while training
         """
         net_output = model(**sample['net_input'])
-        input = net_output.view(-1, net_output.size(-1))
+        lprobs = model.get_normalized_probs(net_output, log_probs=True)
         target = sample['target'].view(-1)
-        loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
+        loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
+                          reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
-            'loss': loss.data[0],
+            'loss': loss.data[0] if reduce else loss.data,
             'ntokens': sample['ntokens'],
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
@@ -39,7 +41,12 @@ class CrossEntropyCriterion(FairseqCriterion):
     @staticmethod
     def aggregate_logging_outputs(logging_outputs):
         """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
+        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
-        return {
-            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
+        agg_output = {
+            'loss': loss_sum / sample_size / math.log(2),
         }
+        if sample_size != ntokens:
+            agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
+        return agg_output
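For context, a minimal sketch (not part of this commit) of what the new `reduce` flag controls: with reduction disabled, `nll_loss` returns one loss value per token instead of a single summed scalar. The tensors below are made up, and newer PyTorch spells the same option `reduction='none'`/`'sum'` instead of the `size_average`/`reduce` pair used in the diff.

```python
import torch
import torch.nn.functional as F

# toy log-probabilities for 3 tokens over a 5-word vocabulary
lprobs = F.log_softmax(torch.randn(3, 5), dim=-1)
target = torch.tensor([1, 4, 0])

summed = F.nll_loss(lprobs, target, reduction='sum')    # scalar, as with reduce=True
per_tok = F.nll_loss(lprobs, target, reduction='none')  # one loss per token, as with reduce=False

assert torch.allclose(summed, per_tok.sum())
```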
fairseq/criterions/fairseq_criterion.py

@@ -16,7 +16,7 @@ class FairseqCriterion(_Loss):
         self.args = args
         self.padding_idx = dst_dict.pad()
 
-    def forward(self, model, sample):
+    def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
 
         Returns a tuple with three elements:
fairseq/criterions/label_smoothed_cross_entropy.py

@@ -11,13 +11,15 @@ import torch
-from torch.autograd.variable import Variable
 import torch.nn.functional as F
 
+from fairseq import utils
 from .fairseq_criterion import FairseqCriterion
 
 
 class LabelSmoothedNLLLoss(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, input, target, eps, padding_idx, weights):
+    def forward(ctx, input, target, eps, padding_idx, weights, reduce=True):
         grad_input = input.new(input.size()).zero_()
         target = target.view(target.size(0), 1)
         grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1)
@@ -34,11 +36,14 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
         grad_input = grad_input.add(-eps / norm)
         ctx.grad_input = grad_input
-        return input.new([grad_input.view(-1).dot(input.view(-1))])
+        if reduce:
+            return input.new([grad_input.view(-1).dot(input.view(-1))])
+        else:
+            return grad_input * input
 
     @staticmethod
     def backward(ctx, grad):
-        return Variable(ctx.grad_input, volatile=True) * grad, None, None, None, None
+        return utils.volatile_variable(ctx.grad_input) * grad, None, None, None, None, None
 
 
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
@@ -48,7 +53,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         self.eps = args.label_smoothing
         self.weights = weights
 
-    def forward(self, model, sample):
+    def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
 
         Returns a tuple with three elements:
@@ -57,12 +62,15 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         3) logging outputs to display while training
         """
         net_output = model(**sample['net_input'])
-        input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
+        lprobs = model.get_normalized_probs(net_output, log_probs=True)
         target = sample['target'].view(-1)
-        loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights)
+        loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, self.weights,
+                                          reduce)
+        nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
+                              reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
-            'loss': loss.data[0],
+            'loss': loss.data[0] if reduce else loss.data,
+            'nll_loss': nll_loss.data[0] if reduce else loss.data,
             'ntokens': sample['ntokens'],
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
@@ -70,7 +78,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
     @staticmethod
     def aggregate_logging_outputs(logging_outputs):
         """Aggregate logging outputs from data parallel training."""
+        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         return {
             'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
+            'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2),
         }
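The autograd Function above constructs the smoothed gradient directly. For intuition only, here is a minimal sketch of the conventional label-smoothing formulation it corresponds to; this is not the commit's code, the padding handling and the optional `weights` argument are omitted, and the toy tensors are made up.

```python
import torch
import torch.nn.functional as F

def label_smoothed_nll(lprobs, target, eps):
    """Mix the true-target NLL with the expected loss under a uniform target."""
    nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
    smooth = -lprobs.mean(dim=-1)  # uniform distribution over the vocabulary
    return ((1.0 - eps) * nll + eps * smooth).sum()

lprobs = F.log_softmax(torch.randn(4, 10), dim=-1)
target = torch.randint(0, 10, (4,))
print(label_smoothed_nll(lprobs, target, eps=0.1))
```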
fairseq/data.py

@@ -175,6 +175,23 @@ def skip_group_enumerator(it, ngpus, offset=0):
         yield (idx, res)
 
 
+class sharded_iterator(object):
+
+    def __init__(self, itr, num_shards, shard_id):
+        assert shard_id >= 0 and shard_id < num_shards
+        self.itr = itr
+        self.num_shards = num_shards
+        self.shard_id = shard_id
+
+    def __len__(self):
+        return len(self.itr)
+
+    def __iter__(self):
+        for i, v in enumerate(self.itr):
+            if i % self.num_shards == self.shard_id:
+                yield v
+
+
 class LanguagePairDataset(object):
 
     # padding constants
@@ -212,13 +229,15 @@ class LanguagePairDataset(object):
         return {
             'id': torch.LongTensor([s['id'].item() for s in samples]),
-            'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
-            # we create a shifted version of targets for feeding the previous
-            # output token(s) into the next decoder step
-            'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
-                                  move_eos_to_beginning=True),
-            'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
             'ntokens': sum(len(s['target']) for s in samples),
+            'net_input': {
+                'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
+                # we create a shifted version of targets for feeding the
+                # previous output token(s) into the next decoder step
+                'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
+                                      move_eos_to_beginning=True),
+            },
+            'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
         }
 
     @staticmethod
@@ -278,9 +297,10 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
             if ignore_invalid_inputs:
                 ignored.append(idx)
                 continue
-            raise Exception("Unable to handle input id {} of size {} / {}.".format(
-                idx, src.sizes[idx], dst.sizes[idx]))
+            raise Exception((
+                "Sample #{} has size (src={}, dst={}) but max size is {}."
+                " Skip this example with --skip-invalid-size-inputs-valid-test"
+            ).format(idx, src.sizes[idx], dst.sizes[idx], max_positions))
 
         sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
         num_tokens = (len(batch) + 1) * sample_len
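The new `sharded_iterator` simply deals the underlying iterator out round-robin across shards. A quick illustration (the input range is made up):

```python
from fairseq.data import sharded_iterator

# each of 3 shards sees every third element of the underlying iterator
print(list(sharded_iterator(range(10), num_shards=3, shard_id=1)))  # [1, 4, 7]
```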
fairseq/dictionary.py

@@ -113,6 +113,8 @@ class Dictionary(object):
         try:
             with open(f, 'r', encoding='utf-8') as fd:
                 return Dictionary.load(fd)
+        except FileNotFoundError as fnfe:
+            raise fnfe
         except:
             raise Exception("Incorrect encoding detected in {}, please "
                             "rebuild the dataset".format(f))
fairseq/models/fairseq_decoder.py

@@ -7,13 +7,24 @@
 #
 
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class FairseqDecoder(nn.Module):
     """Base class for decoders."""
 
-    def __init__(self):
+    def __init__(self, dictionary):
         super().__init__()
+        self.dictionary = dictionary
+
+    def get_normalized_probs(self, net_output, log_probs):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        vocab = net_output.size(-1)
+        net_output1 = net_output.view(-1, vocab)
+        if log_probs:
+            return F.log_softmax(net_output1, dim=1).view_as(net_output)
+        else:
+            return F.softmax(net_output1, dim=1).view_as(net_output)
 
     def max_positions(self):
         """Maximum input length supported by the decoder."""
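A small standalone check of what the new helper does: softmax (or log-softmax) over the vocabulary dimension while preserving the original output shape. The tensor sizes below are made up.

```python
import torch
import torch.nn.functional as F

net_output = torch.randn(2, 7, 100)                 # batch x time x vocab
flat = net_output.view(-1, net_output.size(-1))
probs = F.softmax(flat, dim=1).view_as(net_output)

# rows sum to one over the vocabulary dimension
assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 7))
```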
fairseq/models/fairseq_encoder.py

@@ -12,8 +12,9 @@ import torch.nn as nn
 class FairseqEncoder(nn.Module):
     """Base class for encoders."""
 
-    def __init__(self):
+    def __init__(self, dictionary):
         super().__init__()
+        self.dictionary = dictionary
 
     def max_positions(self):
         """Maximum input length supported by the encoder."""
fairseq/models/fairseq_incremental_decoder.py

@@ -12,8 +12,8 @@ from . import FairseqDecoder
 class FairseqIncrementalDecoder(FairseqDecoder):
     """Base class for incremental decoders."""
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, dictionary):
+        super().__init__(dictionary)
         self._is_incremental_eval = False
         self._incremental_state = {}
@@ -37,7 +37,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
         with model.decoder.incremental_inference():
             for step in range(maxlen):
                 out, _ = model.decoder(tokens[:, :step], encoder_out)
-                probs = torch.nn.functional.log_softmax(out[:, -1, :])
+                probs = model.get_normalized_probs(out[:, -1, :], log_probs=False)
         ```
         """
 
     class IncrementalInference(object):
@@ -86,6 +86,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
         beam_size is required if using BeamableMM.
         """
         if self._is_incremental_eval:
+            del self._incremental_state
             self._incremental_state = {}
 
         def apply_clear_incremental_state(module):
@@ -110,7 +111,9 @@ class FairseqIncrementalDecoder(FairseqDecoder):
     def set_beam_size(self, beam_size):
         """Sets the beam size in the decoder and all children."""
-        def apply_set_beam_size(module):
-            if module != self and hasattr(module, 'set_beam_size'):
-                module.set_beam_size(beam_size)
-        self.apply(apply_set_beam_size)
+        if getattr(self, '_beam_size', -1) != beam_size:
+            def apply_set_beam_size(module):
+                if module != self and hasattr(module, 'set_beam_size'):
+                    module.set_beam_size(beam_size)
+            self.apply(apply_set_beam_size)
+            self._beam_size = beam_size
fairseq/models/fairseq_model.py

@@ -35,6 +35,10 @@ class FairseqModel(nn.Module):
         decoder_out, _ = self.decoder(input_tokens, encoder_out)
         return decoder_out.view(-1, decoder_out.size(-1))
 
+    def get_normalized_probs(self, net_output, log_probs):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        return self.decoder.get_normalized_probs(net_output, log_probs)
+
     def max_encoder_positions(self):
         """Maximum input length supported by the encoder."""
         return self.encoder.max_positions()
@@ -62,6 +66,11 @@ class FairseqModel(nn.Module):
         return self.apply(apply_remove_weight_norm)
 
+        def apply_make_generation_fast_(module):
+            if module != self and hasattr(module, 'make_generation_fast_'):
+                module.make_generation_fast_(**kwargs)
+        self.apply(apply_make_generation_fast_)
+
         def train(mode):
             if mode:
                 raise RuntimeError('cannot train after make_generation_fast')
@@ -69,8 +78,3 @@ class FairseqModel(nn.Module):
         # this model should no longer be used for training
         self.eval()
         self.train = train
-
-        def apply_make_generation_fast_(module):
-            if module != self and hasattr(module, 'make_generation_fast_'):
-                module.make_generation_fast_(**kwargs)
-        self.apply(apply_make_generation_fast_)
fairseq/models/fconv.py

@@ -13,26 +13,11 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from fairseq.data import LanguagePairDataset
-from fairseq.modules import BeamableMM, GradMultiply, LinearizedConvolution
+from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution
 from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel
 
 
-def make_positions(tokens, padding_idx, left_pad, offset=0):
-    seqlen = tokens.size(1)
-    if not hasattr(make_positions, 'range'):
-        make_positions.range = tokens.new()
-    if make_positions.range.numel() < offset + seqlen:
-        # offset positions by the padding index
-        torch.arange(padding_idx + 1, padding_idx + 1 + offset + seqlen,
-                     out=make_positions.range)
-    mask = tokens.ne(padding_idx)
-    positions = make_positions.range[offset:offset + seqlen].expand_as(tokens)
-    if left_pad:
-        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
-    return tokens.clone().masked_scatter_(mask, positions[mask])
-
-
 class FConvModel(FairseqModel):
     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)
@@ -43,15 +28,15 @@ class FConvEncoder(FairseqEncoder):
     """Convolutional encoder"""
     def __init__(self, dictionary, embed_dim=512, max_positions=1024,
                  convolutions=((512, 3),) * 20, dropout=0.1):
-        super().__init__()
-        self.dictionary = dictionary
+        super().__init__(dictionary)
         self.dropout = dropout
         self.num_attention_layers = None
 
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
-        self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
+        self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
+                                                   left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)
 
         in_channels = convolutions[0][0]
         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
@@ -68,11 +53,8 @@ class FConvEncoder(FairseqEncoder):
         self.fc2 = Linear(in_channels, embed_dim)
 
     def forward(self, src_tokens):
-        positions = Variable(make_positions(src_tokens.data, self.dictionary.pad(),
-                                            left_pad=LanguagePairDataset.LEFT_PAD_SOURCE))
-
         # embed tokens and positions
-        x = self.embed_tokens(src_tokens) + self.embed_positions(positions)
+        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
         x = F.dropout(x, p=self.dropout, training=self.training)
         input_embedding = x
@@ -87,7 +69,7 @@ class FConvEncoder(FairseqEncoder):
             residual = x if proj is None else proj(x)
             x = F.dropout(x, p=self.dropout, training=self.training)
             x = conv(x)
-            x = F.glu(x, dim=-1)
+            x = F.glu(x, dim=2)
             x = (x + residual) * math.sqrt(0.5)
 
         # T x B x C -> B x T x C
@@ -106,7 +88,7 @@ class FConvEncoder(FairseqEncoder):
     def max_positions(self):
         """Maximum input length supported by the encoder."""
-        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
+        return self.embed_positions.max_positions()
 
 
 class AttentionLayer(nn.Module):
@@ -128,7 +110,7 @@ class AttentionLayer(nn.Module):
         # softmax over last dim
         sz = x.size()
-        x = F.softmax(x.view(sz[0] * sz[1], sz[2]))
+        x = F.softmax(x.view(sz[0] * sz[1], sz[2]), dim=1)
         x = x.view(sz)
         attn_scores = x
@@ -145,28 +127,32 @@ class AttentionLayer(nn.Module):
     def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
         """Replace torch.bmm with BeamableMM."""
         if beamable_mm_beam_size is not None:
-            self.bmm = BeamableMM(beamable_mm_beam_size)
+            del self.bmm
+            self.add_module('bmm', BeamableMM(beamable_mm_beam_size))
 
 
 class FConvDecoder(FairseqIncrementalDecoder):
     """Convolutional decoder"""
     def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
                  max_positions=1024, convolutions=((512, 3),) * 20,
-                 attention=True, dropout=0.1):
-        super().__init__()
+                 attention=True, dropout=0.1, share_embed=False):
+        super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
-        self.dictionary = dictionary
         self.dropout = dropout
 
         in_channels = convolutions[0][0]
         if isinstance(attention, bool):
             # expand True into [True, True, ...] and do the same with False
             attention = [attention] * len(convolutions)
         if not isinstance(attention, list) or len(attention) != len(convolutions):
             raise ValueError('Attention is expected to be a list of booleans of '
                              'length equal to the number of layers.')
 
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
-        self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
+        self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
+                                                   left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
 
         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
         self.projections = nn.ModuleList()
@@ -183,35 +169,28 @@ class FConvDecoder(FairseqIncrementalDecoder):
                            if attention[i] else None)
             in_channels = out_channels
         self.fc2 = Linear(in_channels, out_embed_dim)
-        self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
+        if share_embed:
+            assert out_embed_dim == embed_dim, \
+                "Shared embed weights implies same dimensions " \
+                " out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim)
+            self.fc3 = nn.Linear(out_embed_dim, num_embeddings)
+            self.fc3.weight = self.embed_tokens.weight
+        else:
+            self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
 
-    def forward(self, input_tokens, encoder_out):
-        if self._is_incremental_eval:
-            return self.incremental_forward(input_tokens, encoder_out)
-        else:
-            return self.batch_forward(input_tokens, encoder_out)
-
-    def batch_forward(self, input_tokens, encoder_out):
-        """Forward pass for decoding multiple time steps in batch mode."""
-        positions = Variable(make_positions(input_tokens.data, self.dictionary.pad(),
-                                            left_pad=LanguagePairDataset.LEFT_PAD_TARGET))
-        return self._forward(input_tokens, positions, encoder_out)
-
-    def incremental_forward(self, input_tokens, encoder_out):
-        """Forward pass for one time step."""
-        # positions is the same for every token when decoding a single step
-        positions = Variable(input_tokens.data.new(1, 1).fill_(
-            self.dictionary.pad() + input_tokens.size(1)))
-        # keep only the last token for incremental forward pass
-        return self._forward(input_tokens[:, -1:], positions, encoder_out)
-
-    def _forward(self, input_tokens, positions, encoder_out):
+    def forward(self, input_tokens, encoder_out):
         # split and transpose encoder outputs
         encoder_a, encoder_b = self._split_encoder_out(encoder_out)
 
+        # embed positions
+        positions = self.embed_positions(input_tokens)
+
+        if self._is_incremental_eval:
+            # keep only the last token for incremental forward pass
+            input_tokens = input_tokens[:, -1:]
+
         # embed tokens and positions
-        x = self.embed_tokens(input_tokens) + self.embed_positions(positions)
+        x = self.embed_tokens(input_tokens) + positions
         x = F.dropout(x, p=self.dropout, training=self.training)
         target_embedding = x
@@ -230,7 +209,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
             x = F.dropout(x, p=self.dropout, training=self.training)
             x = conv(x)
             x = conv.remove_future_timesteps(x)
-            x = F.glu(x)
+            x = F.glu(x, dim=2)
 
             # attention
             if attention is not None:
@@ -264,7 +243,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
     def max_positions(self):
         """Maximum output length supported by the decoder."""
-        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
+        return self.embed_positions.max_positions()
 
     def upgrade_state_dict(self, state_dict):
         if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
@@ -304,6 +283,12 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m
 
 
+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
+    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
+    m.weight.data.normal_(0, 0.1)
+    return m
+
+
 def Linear(in_features, out_features, dropout=0):
     """Weight-normalized Linear layer (input: N x T x C)"""
     m = nn.Linear(in_features, out_features)
@@ -394,6 +379,7 @@ def parse_arch(args):
     args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20')
     args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
     args.decoder_attention = getattr(args, 'decoder_attention', 'True')
+    args.share_input_output_embed = getattr(args, 'share_input_output_embed', False)
     return args
@@ -413,5 +399,6 @@ def build_model(args, src_dict, dst_dict):
         attention=eval(args.decoder_attention),
         dropout=args.dropout,
         max_positions=args.max_target_positions,
+        share_embed=args.share_input_output_embed,
     )
     return FConvModel(encoder, decoder)
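The new `share_embed` option ties the decoder's output projection to its input embedding matrix, which is why the diff asserts that `out_embed_dim == embed_dim`. A minimal sketch of the idea outside fairseq (the dimensions are made up):

```python
import torch.nn as nn

vocab, dim = 1000, 256
embed_tokens = nn.Embedding(vocab, dim)   # weight shape: (vocab, dim)
fc3 = nn.Linear(dim, vocab)               # weight shape: (vocab, dim)

# both layers now share one (vocab x dim) parameter tensor
fc3.weight = embed_tokens.weight
assert fc3.weight.data_ptr() == embed_tokens.weight.data_ptr()
```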
fairseq/models/lstm.py

@@ -23,8 +23,7 @@ class LSTMEncoder(FairseqEncoder):
     """LSTM encoder."""
     def __init__(self, dictionary, embed_dim=512, num_layers=1, dropout_in=0.1,
                  dropout_out=0.1):
-        super().__init__()
-        self.dictionary = dictionary
+        super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
@@ -94,7 +93,7 @@ class AttentionLayer(nn.Module):
         # compute attention
         attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2)
-        attn_scores = F.softmax(attn_scores.t()).t()  # srclen x bsz
+        attn_scores = F.softmax(attn_scores.t(), dim=1).t()  # srclen x bsz
 
         # sum weighted sources
         x = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0)
@@ -108,8 +107,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
     def __init__(self, dictionary, encoder_embed_dim=512, embed_dim=512,
                  out_embed_dim=512, num_layers=1, dropout_in=0.1,
                  dropout_out=0.1, attention=True):
-        super().__init__()
-        self.dictionary = dictionary
+        super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
fairseq/modules/__init__.py

@@ -9,11 +9,13 @@
 from .beamable_mm import BeamableMM
 from .conv_tbc import ConvTBC
 from .grad_multiply import GradMultiply
+from .learned_positional_embedding import LearnedPositionalEmbedding
 from .linearized_convolution import LinearizedConvolution
 
 __all__ = [
     'BeamableMM',
     'ConvTBC',
     'GradMultiply',
+    'LearnedPositionalEmbedding',
     'LinearizedConvolution',
 ]
fairseq/modules/conv_tbc.py

@@ -7,9 +7,11 @@
 #
 
 import torch
-from torch.autograd import Variable, Function
+from torch.autograd import Function
 from torch.nn.modules.utils import _single
 
+from fairseq import utils
+
 try:
     from fairseq import temporal_convolution_tbc
 except ImportError as e:
@@ -93,9 +95,9 @@ class ConvTBCFunction(Function):
             input, weight)
 
-        grad_input = Variable(grad_input, volatile=True)
-        grad_weight = Variable(grad_weight, volatile=True)
-        grad_bias = Variable(grad_bias, volatile=True)
+        grad_input = utils.volatile_variable(grad_input)
+        grad_weight = utils.volatile_variable(grad_weight)
+        grad_bias = utils.volatile_variable(grad_bias)
 
         return grad_input, grad_weight, grad_bias, None
fairseq/modules/learned_positional_embedding.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class LearnedPositionalEmbedding(nn.Embedding):
    """This module learns positional embeddings up to a fixed maximum size.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.left_pad = left_pad
        self._is_incremental_eval = False

    def incremental_eval(self, mode=True):
        self._is_incremental_eval = mode

    def forward(self, input):
        """Input is expected to be of size [bsz x seqlen]."""
        if self._is_incremental_eval:
            # positions is the same for every token when decoding a single step
            positions = Variable(input.data.new(1, 1).fill_(self.padding_idx + input.size(1)))
        else:
            positions = Variable(self.make_positions(input.data))
        return super().forward(positions)

    def max_positions(self):
        """Maximum number of supported positions."""
        return self.num_embeddings - self.padding_idx - 1

    def make_positions(self, input):
        """Replace non-padding symbols with their position numbers."""
        if not hasattr(self, 'range_buf'):
            self.range_buf = input.new()
        seqlen = input.size(1)
        if self.range_buf.numel() < seqlen:
            # offset positions by the padding index
            torch.arange(self.padding_idx + 1, self.padding_idx + 1 + seqlen, out=self.range_buf)
        mask = input.ne(self.padding_idx)
        positions = self.range_buf[:seqlen].expand_as(input)
        if self.left_pad:
            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
        return input.clone().masked_scatter_(mask, positions[mask])
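To see what `make_positions` produces, here is a small sketch using plain tensor ops (the `Variable` plumbing above reflects the PyTorch 0.3 API of the time; the toy batch below is made up). With `padding_idx=1` and right padding, non-pad tokens are numbered from `padding_idx + 1` and pad slots keep the padding index.

```python
import torch

padding_idx = 1
tokens = torch.tensor([[5, 6, 7, 1],     # last slot is padding
                       [8, 9, 1, 1]])    # right-padded
mask = tokens.ne(padding_idx)
positions = torch.arange(padding_idx + 1, padding_idx + 1 + tokens.size(1)).expand_as(tokens)
out = tokens.clone().masked_scatter_(mask, positions[mask])
print(out)
# tensor([[2, 3, 4, 1],
#         [2, 3, 1, 1]])
```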
fairseq/modules/linearized_convolution.py

@@ -8,6 +8,9 @@
 import torch
 import torch.nn.functional as F
 
+from fairseq import utils
+
 from .conv_tbc import ConvTBC
@@ -65,8 +68,9 @@ class LinearizedConvolution(ConvTBC):
             self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
         # append next input
         self.input_buffer[:, -1, :] = input[:, -1, :]
-        input = torch.autograd.Variable(self.input_buffer, volatile=True)
-        output = F.linear(input.view(bsz, -1), weight, self.bias)
+        input = utils.volatile_variable(self.input_buffer)
+        with utils.maybe_no_grad():
+            output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)
 
     def clear_incremental_state(self):
fairseq/multiprocessing_trainer.py

@@ -15,9 +15,10 @@ import math
 import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
 
-from fairseq import meters, nccl, utils
+from fairseq import nccl, utils
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
-from fairseq.nag import NAG
+from fairseq.optim.nag import NAG
+from fairseq.optim.adam import Adam
 
 
 class MultiprocessingTrainer(MultiprocessingEventLoop):
@@ -95,7 +96,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                 'betas': eval(self.args.adam_betas),
                 'weight_decay': self.args.weight_decay,
             }
-            return torch.optim.Adam(self.model.parameters(), **self._override_optim_state)
+            return Adam(self.model.parameters(), **self._override_optim_state)
         elif self.args.optimizer == 'nag':
             self._override_optim_state = {
                 'lr': self.args.lr[0],
@@ -116,6 +117,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _build_lr_scheduler(self):
         if len(self.args.lr) > 1 or self.args.force_anneal > 0:
             lrs = self.args.lr
             def anneal(e):
                 if e < self.args.force_anneal:
                     # use fixed LR schedule
@@ -123,6 +125,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                 else:
                     next_lr = lrs[-1] * self.args.lrshrink ** (e + 1 - self.args.force_anneal)
                 return next_lr / lrs[0]  # correct for scaling from LambdaLR
             lr_scheduler = LambdaLR(self.optimizer, anneal)
             lr_scheduler.best = None
         else:
@@ -225,20 +228,21 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self.model.train()
         self.optimizer.zero_grad()
 
-        sample_size, logging_output, oom = 0, {}, False
-        if self._sample is not None:
-            try:
-                # calculate loss and sample size
-                self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
-            except RuntimeError as e:
-                if not eval and 'out of memory' in str(e):
-                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
-                    oom = True
-                    self.loss = None
-                    if hasattr(torch.cuda, 'empty_cache'):
-                        torch.cuda.empty_cache()
-                else:
-                    raise e
+        with utils.maybe_no_grad(eval):
+            sample_size, logging_output, oom = 0, {}, False
+            if self._sample is not None:
+                try:
+                    # calculate loss and sample size
+                    self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+                except RuntimeError as e:
+                    if not eval and 'out of memory' in str(e):
+                        print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                        oom = True
+                        self.loss = None
+                        if hasattr(torch.cuda, 'empty_cache'):
+                            torch.cuda.empty_cache()
+                    else:
+                        raise e
 
         return sample_size, logging_output, oom
@@ -262,7 +266,10 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._all_reduce_and_rescale_grads(grad_denom)
 
         # clip grads
-        grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
+        if self.args.clip_norm > 0:
+            grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
+        else:
+            grad_norm = math.sqrt(sum([p.grad.data.norm()**2 for p in self.model.parameters()]))
 
         # take an optimization step
         self.optimizer.step()
@@ -378,4 +385,4 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                 self._max_bsz_seen = sample['target'].size(0)
                 torch.cuda.empty_cache()
 
-        self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
+        self._sample = utils.make_variable(sample, volatile=volatile, cuda_device=device_id)
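When `--clip-norm` is 0, the trainer now computes the total gradient norm itself rather than calling `clip_grad_norm`. A quick sketch of that computation on a toy model (the model here is made up):

```python
import math
import torch
import torch.nn as nn

model = nn.Linear(4, 3)
model(torch.randn(2, 4)).sum().backward()

# L2 norm over all parameter gradients, as in the else-branch above
grad_norm = math.sqrt(sum(float(p.grad.data.norm()) ** 2 for p in model.parameters()))
print(grad_norm)
```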
fairseq/optim/adam.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import math

import torch
from torch.optim.optimizer import Optimizer


class Adam(Optimizer):
    """Implements Adam algorithm.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    p.data.add_(-group['weight_decay'], p.data)

                p.data.addcdiv_(-step_size, exp_avg, denom)

        return loss
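For reference, the update that `step` implements is the standard Adam rule (the AMSGrad branch replaces the second-moment estimate with its running maximum):

\[
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \qquad
\theta_t = \theta_{t-1} - \frac{\alpha\,\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t}+\epsilon}
\]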
fairseq/nag.py → fairseq/optim/nag.py (renamed)

@@ -11,7 +11,7 @@ from torch.optim.optimizer import Optimizer, required
 class NAG(Optimizer):
     def __init__(self, params, lr=required, momentum=0, weight_decay=0):
-        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
+        defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
         super(NAG, self).__init__(params, defaults)
 
     def step(self, closure=None):
@@ -29,6 +29,8 @@ class NAG(Optimizer):
             weight_decay = group['weight_decay']
             momentum = group['momentum']
             lr = group['lr']
+            lr_old = group.get('lr_old', lr)
+            lr_correct = lr / lr_old
             for p in group['params']:
                 if p.grad is None:
@@ -43,9 +45,11 @@ class NAG(Optimizer):
                 if weight_decay != 0:
                     p.data.mul_(1 - weight_decay)
-                p.data.add_(momentum * momentum, buf)
+                p.data.add_(momentum * momentum * lr_correct, buf)
                 p.data.add_(-(1 + momentum) * lr, d_p)
 
-                buf.mul_(momentum).add_(-lr, d_p)
+                buf.mul_(momentum * lr_correct).add_(-lr, d_p)
 
+            group['lr_old'] = lr
+
         return loss