OpenDAS / Fairseq · Commits

Commit 8afb7761
authored Apr 24, 2018 by Myle Ott
parent 7c7634f6

Fix tests
Showing 4 changed files with 9 additions and 5 deletions (+9 −5)
fairseq/data.py                 +1 −1
fairseq/dictionary.py           +3 −2
scripts/average_checkpoints.py  +4 −1
tests/utils.py                  +1 −1
fairseq/data.py

@@ -442,7 +442,7 @@ def numpy_seed(seed):
 def get_dummy_batch(ntokens, src_dict, dst_dict, src_len=128, tgt_len=128):
     bsz = int(ntokens / max(src_len, tgt_len))
-    bsz = (bsz // 8) * 8
+    bsz = math.ceil(bsz / 8) * 8
     assert src_dict.pad() == dst_dict.pad()
     pad_idx = src_dict.pad()
     src_vocab, dst_vocab = len(src_dict), len(dst_dict)
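The changed line rounds the dummy batch size up to a multiple of 8 instead of down, so a small ntokens value can no longer produce an empty batch. A minimal standalone sketch of the arithmetic (plain Python, not fairseq code):

import math

def old_round(bsz):
    return (bsz // 8) * 8          # floor to a multiple of 8; can hit 0

def new_round(bsz):
    return math.ceil(bsz / 8) * 8  # round up; at least 8 for any bsz > 0

for bsz in (3, 8, 13):
    print(bsz, old_round(bsz), new_round(bsz))
# 3 -> 0 vs 8, 8 -> 8 vs 8, 13 -> 8 vs 16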
fairseq/dictionary.py

@@ -93,9 +93,10 @@ class Dictionary(object):
                 multiple of 8, which is important on some hardware (e.g., Nvidia
                 Tensor Cores).
         """
-        if nwords == -1:
-            nwords = len(self)
         if padding_factor > 1:
+            if nwords == -1:
+                nwords = len(self)
             i = 0
             while nwords % padding_factor != 0:
                 if nwords >= len(self):
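The reordering resolves the nwords default only inside the padding branch; the loop itself still grows the dictionary until its size is a multiple of padding_factor. A hypothetical standalone sketch of that size calculation (not fairseq code; the real loop appends made-up filler symbols):

def padded_size(current_size, padding_factor=8):
    # Mirror the while-loop above: grow the size until it is a multiple
    # of padding_factor (no-op when padding_factor <= 1).
    nwords = current_size
    if padding_factor > 1:
        while nwords % padding_factor != 0:
            nwords += 1
    return nwords

print(padded_size(100))                    # 104
print(padded_size(100, padding_factor=1))  # 100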
scripts/average_checkpoints.py

@@ -44,7 +44,10 @@ def average_checkpoints(inputs):
         for k in params_keys:
             if k not in params_dict:
                 params_dict[k] = []
-            params_dict[k].append(model_params[k].float())
+            p = model_params[k]
+            if isinstance(p, torch.HalfTensor):
+                p = p.float()
+            params_dict[k].append(p)

     averaged_params = collections.OrderedDict()
     # v should be a list of torch Tensor.
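The replacement upcasts only half-precision tensors before averaging, instead of calling .float() on every parameter unconditionally, so non-float entries keep their original dtype. A small sketch of the pattern with made-up checkpoint dicts (assumes PyTorch is installed):

import collections
import torch

# Two toy "checkpoints": fp16 weights plus an integer buffer that the old
# unconditional .float() would have silently cast to float.
ckpt_a = {'w': torch.ones(2, 2, dtype=torch.half), 'steps': torch.tensor([10])}
ckpt_b = {'w': torch.zeros(2, 2, dtype=torch.half), 'steps': torch.tensor([20])}

params_dict = collections.defaultdict(list)
for model_params in (ckpt_a, ckpt_b):
    for k, p in model_params.items():
        if isinstance(p, torch.HalfTensor):  # only fp16 weights are upcast
            p = p.float()
        params_dict[k].append(p)

print(params_dict['w'][0].dtype)      # torch.float32
print(params_dict['steps'][0].dtype)  # torch.int64 (unchanged)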
tests/utils.py

@@ -21,7 +21,7 @@ def dummy_dictionary(vocab_size, prefix='token_'):
     for i in range(vocab_size):
         token = prefix + str(i)
         d.add_symbol(token)
-    d.finalize()
+    d.finalize(padding_factor=1)  # don't add extra padding symbols
     return d
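Passing padding_factor=1 keeps the dummy dictionary at exactly the requested vocabulary size plus the special symbols, instead of letting finalize() pad it as described in fairseq/dictionary.py above. A rough usage sketch, assuming fairseq at this commit is importable:

from fairseq.dictionary import Dictionary

d1, d2 = Dictionary(), Dictionary()
for d in (d1, d2):
    for i in range(10):
        d.add_symbol('token_' + str(i))

d1.finalize()                  # default: pads the size up (see docstring above)
d2.finalize(padding_factor=1)  # the test fix: no extra padding symbols
print(len(d1), len(d2))        # len(d1) is padded to a multiple of 8; len(d2) is not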