Commit f69206c8
Authored Jul 26, 2018 by Alexei Baevski
Committed by Myle Ott on Sep 03, 2018

fix adaptive softmax indexing

Parent: af38ed48
2 changed files with 26 additions and 7 deletions:

  fairseq/models/fairseq_model.py      +13 -5
  fairseq/modules/adaptive_softmax.py  +13 -2
fairseq/models/fairseq_model.py

@@ -57,11 +57,19 @@ class BaseFairseqModel(nn.Module):
     def upgrade_state_dict(self, state_dict):
         assert state_dict is not None
 
-        def do_upgrade(m):
-            if m != self and hasattr(m, 'upgrade_state_dict'):
-                m.upgrade_state_dict(state_dict)
+        def do_upgrade(m, prefix):
+            if len(prefix) > 0:
+                prefix += '.'
+
+            for n, c in m.named_children():
+                name = prefix + n
+                if hasattr(c, 'upgrade_state_dict_named'):
+                    c.upgrade_state_dict_named(state_dict, name)
+                elif hasattr(c, 'upgrade_state_dict'):
+                    c.upgrade_state_dict(state_dict)
+                do_upgrade(c, name)
 
-        self.apply(do_upgrade)
+        do_upgrade(self, '')
 
     def make_generation_fast_(self, **kwargs):
         """Optimize model for faster generation."""
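For context, the model-level change replaces the flat self.apply(do_upgrade) pass with a recursive walk that builds each submodule's dotted path, so modules defining upgrade_state_dict_named (such as the AdaptiveSoftmax hook added below) receive their full name when a checkpoint is upgraded. The following sketch is not part of the commit and uses hypothetical module names; it only shows which names the traversal produces.

# Minimal sketch, not from the commit; module and attribute names are hypothetical.
import torch.nn as nn


class InnerSoftmax(nn.Module):
    def upgrade_state_dict_named(self, state_dict, name):
        print('upgrading', name)  # receives its dotted path, e.g. 'decoder.softmax'


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.softmax = InnerSoftmax()


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = Decoder()

    def upgrade_state_dict(self, state_dict):
        # same traversal as the new BaseFairseqModel.upgrade_state_dict above
        def do_upgrade(m, prefix):
            if len(prefix) > 0:
                prefix += '.'
            for n, c in m.named_children():
                name = prefix + n
                if hasattr(c, 'upgrade_state_dict_named'):
                    c.upgrade_state_dict_named(state_dict, name)
                elif hasattr(c, 'upgrade_state_dict'):
                    c.upgrade_state_dict(state_dict)
                do_upgrade(c, name)
        do_upgrade(self, '')


Model().upgrade_state_dict({})  # prints: upgrading decoder.softmax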
fairseq/modules/adaptive_softmax.py

@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 
+import torch
 import torch.nn.functional as F
 from torch import nn

@@ -51,6 +52,16 @@ class AdaptiveSoftmax(nn.Module):
         self.apply(init_weights)
 
+        self.register_buffer('version', torch.LongTensor([1]))
+        # versions prior to 1 had a bug that offset indices on the head by 1
+        self.buggy_offset = 0
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        version_name = name + '.version'
+        if version_name not in state_dict:
+            self.buggy_offset = 1
+            state_dict[version_name] = torch.LongTensor([1])
+
     def adapt_target(self, target):
         """
         In order to be efficient, the AdaptiveSoftMax does not compute the

@@ -65,7 +76,7 @@ class AdaptiveSoftmax(nn.Module):
         for i in range(len(self.cutoff) - 1):
             mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
-            new_target[0][mask] = self.cutoff[0] + i - 1
+            new_target[0][mask] = self.cutoff[0] + i - self.buggy_offset
 
             if mask.any():
                 target_idxs.append(mask.nonzero().squeeze(1))

@@ -118,7 +129,7 @@ class AdaptiveSoftmax(nn.Module):
         head_sz = self.cutoff[0] + len(self.tail)
         log_probs[:, :head_sz] = self.lsm(head_y)
-        tail_priors = log_probs[:, self.cutoff[0] - 1: head_sz - 1].clone()
+        tail_priors = log_probs[:, self.cutoff[0] - self.buggy_offset: head_sz - self.buggy_offset].clone()
 
         for i in range(len(self.tail)):
             start = self.cutoff[i]
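The indexing fix itself: earlier versions mapped tail cluster i to head slot cutoff[0] + i - 1, one position too low, and sliced tail_priors with the same off-by-one. New models now use an offset of 0, while upgrade_state_dict_named keeps buggy_offset = 1 for checkpoints saved before the version buffer existed, so old models keep the indexing they were trained with. A small worked example with hypothetical cutoffs, not part of the commit:

# Hypothetical cutoffs; illustrates the off-by-one, not taken from the commit.
cutoff = [10000, 20000, 40000]       # head covers the 10000 most frequent words
tail = cutoff[1:]                    # two tail clusters
head_sz = cutoff[0] + len(tail)      # 10002 head outputs: 10000 words + 2 cluster slots

for i in range(len(cutoff) - 1):
    fixed = cutoff[0] + i            # 10000, 10001: the dedicated cluster slots
    buggy = cutoff[0] + i - 1        # 9999, 10000: cluster 0 collides with word id 9999
    print(i, fixed, buggy)

# tail_priors should read the cluster-slot columns of the head distribution:
fixed_slice = slice(cutoff[0], head_sz)          # columns 10000:10002
buggy_slice = slice(cutoff[0] - 1, head_sz - 1)  # columns 9999:10001, shifted by one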