Commit 8afb7761 authored by Myle Ott's avatar Myle Ott
Browse files

Fix tests

parent 7c7634f6
...@@ -442,7 +442,7 @@ def numpy_seed(seed): ...@@ -442,7 +442,7 @@ def numpy_seed(seed):
def get_dummy_batch(ntokens, src_dict, dst_dict, src_len=128, tgt_len=128):
    bsz = int(ntokens / max(src_len, tgt_len))
    bsz = math.ceil(bsz / 8) * 8
    assert src_dict.pad() == dst_dict.pad()
    pad_idx = src_dict.pad()
    src_vocab, dst_vocab = len(src_dict), len(dst_dict)
......
...@@ -93,9 +93,10 @@ class Dictionary(object): ...@@ -93,9 +93,10 @@ class Dictionary(object):
        multiple of 8, which is important on some hardware (e.g., Nvidia
        Tensor Cores).
        """
        if padding_factor > 1:
            if nwords == -1:
                nwords = len(self)
            i = 0
            while nwords % padding_factor != 0:
                if nwords >= len(self):
......
...@@ -44,7 +44,10 @@ def average_checkpoints(inputs): ...@@ -44,7 +44,10 @@ def average_checkpoints(inputs):
        for k in params_keys:
            if k not in params_dict:
                params_dict[k] = []
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            params_dict[k].append(p)
    averaged_params = collections.OrderedDict()
    # v should be a list of torch Tensor.
......
...@@ -21,7 +21,7 @@ def dummy_dictionary(vocab_size, prefix='token_'): ...@@ -21,7 +21,7 @@ def dummy_dictionary(vocab_size, prefix='token_'):
    for i in range(vocab_size):
        token = prefix + str(i)
        d.add_symbol(token)
    d.finalize(padding_factor=1)  # don't add extra padding symbols
    return d
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment