"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "58237364b1780223f48a80256f56408efe7b59a0"
Commit d85b61d6 authored by Angela Fan, committed by Myle Ott
Browse files

fix to adding tokens to dictionary while thresholding

parent 7f538f54
...@@ -96,14 +96,6 @@ class Dictionary(object): ...@@ -96,14 +96,6 @@ class Dictionary(object):
if nwords == -1: if nwords == -1:
nwords = len(self) nwords = len(self)
if padding_factor > 1:
i = 0
while nwords % padding_factor != 0:
if nwords >= len(self):
self.add_symbol('madeupword{:04d}'.format(i))
i += 1
nwords += 1
new_symbols = self.symbols[:self.nspecial] new_symbols = self.symbols[:self.nspecial]
new_count = self.count[:self.nspecial] new_count = self.count[:self.nspecial]
...@@ -114,8 +106,16 @@ class Dictionary(object): ...@@ -114,8 +106,16 @@ class Dictionary(object):
new_count.append(count) new_count.append(count)
else: else:
break break
threshold_nwords = len(new_symbols)
if padding_factor > 1:
i = 0
while threshold_nwords % padding_factor != 0:
new_symbols.append('madeupword{:04d}'.format(i))
i += 1
threshold_nwords += 1
assert min(new_count[self.nspecial:]) >= threshold assert min(new_count[self.nspecial:]) >= threshold
assert len(new_symbols) <= nwords
assert len(new_symbols) % padding_factor == 0 assert len(new_symbols) % padding_factor == 0
self.count = tuple(new_count) self.count = tuple(new_count)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment