Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
b9956a6a
Commit
b9956a6a
authored
Aug 27, 2018
by
Myle Ott
Browse files
Fix FP16 version comparison
parent
753935ef
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
3 deletions
+3
-3
fairseq/models/fconv.py
fairseq/models/fconv.py
+1
-1
fairseq/models/transformer.py
fairseq/models/transformer.py
+2
-2
No files found.
fairseq/models/fconv.py
View file @
b9956a6a
...
...
@@ -497,7 +497,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
return
self
.
embed_positions
.
max_positions
()
if
self
.
embed_positions
is
not
None
else
float
(
'inf'
)
def
upgrade_state_dict
(
self
,
state_dict
):
if
state_dict
.
get
(
'decoder.version'
,
torch
.
Tensor
([
1
]))[
0
]
<
2
:
if
utils
.
item
(
state_dict
.
get
(
'decoder.version'
,
torch
.
Tensor
([
1
]))[
0
]
)
<
2
:
# old models use incorrect weight norm dimension
for
i
,
conv
in
enumerate
(
self
.
convolutions
):
# reconfigure weight norm
...
...
fairseq/models/transformer.py
View file @
b9956a6a
...
...
@@ -277,7 +277,7 @@ class TransformerEncoder(FairseqEncoder):
if
'encoder.embed_positions.weights'
in
state_dict
:
del
state_dict
[
'encoder.embed_positions.weights'
]
state_dict
[
'encoder.embed_positions._float_tensor'
]
=
torch
.
FloatTensor
(
1
)
if
state_dict
.
get
(
'encoder.version'
,
torch
.
Tensor
([
1
]))[
0
]
<
2
:
if
utils
.
item
(
state_dict
.
get
(
'encoder.version'
,
torch
.
Tensor
([
1
]))[
0
]
)
<
2
:
# earlier checkpoints did not normalize after the stack of layers
self
.
layer_norm
=
None
self
.
normalize
=
False
...
...
@@ -415,7 +415,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if
k
in
state_dict
:
state_dict
[
'decoder.layers.{}.{}.{}'
.
format
(
i
,
new
,
m
)]
=
state_dict
[
k
]
del
state_dict
[
k
]
if
state_dict
.
get
(
'decoder.version'
,
torch
.
Tensor
([
1
]))[
0
]
<
2
:
if
utils
.
item
(
state_dict
.
get
(
'decoder.version'
,
torch
.
Tensor
([
1
]))[
0
]
)
<
2
:
# earlier checkpoints did not normalize after the stack of layers
self
.
layer_norm
=
None
self
.
normalize
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment