Unverified Commit b48faae3 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Marian Conversion] Fix eos_token_id conversion in conversion script (#14320)

parent c016dbdb
...@@ -455,7 +455,7 @@ BART_CONVERTER = { # for each encoder and decoder layer ...@@ -455,7 +455,7 @@ BART_CONVERTER = { # for each encoder and decoder layer
class OpusState: class OpusState:
def __init__(self, source_dir): def __init__(self, source_dir, eos_token_id=0):
npz_path = find_model_file(source_dir) npz_path = find_model_file(source_dir)
self.state_dict = np.load(npz_path) self.state_dict = np.load(npz_path)
cfg = load_config_from_state_dict(self.state_dict) cfg = load_config_from_state_dict(self.state_dict)
...@@ -492,7 +492,8 @@ class OpusState: ...@@ -492,7 +492,8 @@ class OpusState:
d_model=cfg["dim-emb"], d_model=cfg["dim-emb"],
activation_function=cfg["transformer-aan-activation"], activation_function=cfg["transformer-aan-activation"],
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,
eos_token_id=0, eos_token_id=eos_token_id,
forced_eos_token_id=eos_token_id,
bos_token_id=0, bos_token_id=0,
max_position_embeddings=cfg["dim-emb"], max_position_embeddings=cfg["dim-emb"],
scale_embedding=True, scale_embedding=True,
...@@ -595,7 +596,11 @@ def convert(source_dir: Path, dest_dir): ...@@ -595,7 +596,11 @@ def convert(source_dir: Path, dest_dir):
tokenizer = MarianTokenizer.from_pretrained(str(source_dir)) tokenizer = MarianTokenizer.from_pretrained(str(source_dir))
tokenizer.save_pretrained(dest_dir) tokenizer.save_pretrained(dest_dir)
opus_state = OpusState(source_dir) # retrieve EOS token and set correctly
tokenizer_has_eos_token_id = hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None
eos_token_id = tokenizer.eos_token_id if tokenizer_has_eos_token_id else 0
opus_state = OpusState(source_dir, eos_token_id=eos_token_id)
if opus_state.cfg["vocab_size"] != len(tokenizer.encoder): if opus_state.cfg["vocab_size"] != len(tokenizer.encoder):
raise ValueError( raise ValueError(
f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched" f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment