Unverified Commit b48faae3 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Marian Conversion] Fix eos_token_id conversion in conversion script (#14320)

parent c016dbdb
......@@ -455,7 +455,7 @@ BART_CONVERTER = { # for each encoder and decoder layer
class OpusState:
def __init__(self, source_dir):
def __init__(self, source_dir, eos_token_id=0):
npz_path = find_model_file(source_dir)
self.state_dict = np.load(npz_path)
cfg = load_config_from_state_dict(self.state_dict)
......@@ -492,7 +492,8 @@ class OpusState:
d_model=cfg["dim-emb"],
activation_function=cfg["transformer-aan-activation"],
pad_token_id=self.pad_token_id,
eos_token_id=0,
eos_token_id=eos_token_id,
forced_eos_token_id=eos_token_id,
bos_token_id=0,
max_position_embeddings=cfg["dim-emb"],
scale_embedding=True,
......@@ -595,7 +596,11 @@ def convert(source_dir: Path, dest_dir):
tokenizer = MarianTokenizer.from_pretrained(str(source_dir))
tokenizer.save_pretrained(dest_dir)
opus_state = OpusState(source_dir)
# retrieve EOS token and set correctly
tokenizer_has_eos_token_id = hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None
eos_token_id = tokenizer.eos_token_id if tokenizer_has_eos_token_id else 0
opus_state = OpusState(source_dir, eos_token_id=eos_token_id)
if opus_state.cfg["vocab_size"] != len(tokenizer.encoder):
raise ValueError(
f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment