Unverified Commit c7ccb2e7 authored by Baizhou Huang, committed by GitHub

Fix assertion in models (#14090)



* replace assertions in src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py

* replace assertions in src/transformers/models/marian/convert_marian_to_pytorch.py

* Update src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/marian/convert_marian_to_pytorch.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/marian/convert_marian_to_pytorch.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/marian/convert_marian_to_pytorch.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/marian/convert_marian_to_pytorch.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/marian/convert_marian_to_pytorch.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/marian/convert_marian_to_pytorch.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/marian/convert_marian_to_pytorch.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/marian/convert_marian_to_pytorch.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: skpig <1900012999@pku.edu.cn>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 16d7b70b
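
Both conversion scripts get the same treatment: sanity-check assert statements become explicit if/raise ValueError checks. Unlike asserts, these checks still run when Python is started with the -O flag (which strips asserts), and a failure carries a descriptive message instead of a bare AssertionError. A minimal sketch of the pattern, assuming a missing_keys list like the one returned by model.load_state_dict(state_dict, strict=False); the check_missing_keys helper is illustrative, not part of the patch:

    # Before: skipped entirely under `python -O`, and a failure surfaces
    # as a bare AssertionError with no context.
    # assert len(missing_keys) == 1 and missing_keys[0] == "embeddings.position_ids"

    # After: always evaluated, with an actionable error message.
    def check_missing_keys(missing_keys):
        if not (len(missing_keys) == 1 and missing_keys[0] == "embeddings.position_ids"):
            raise ValueError(
                f"Missing keys {', '.join(missing_keys)}. Expected only missing embeddings.position_ids"
            )

    check_missing_keys(["embeddings.position_ids"])  # passes silently
    try:
        check_missing_keys(["embeddings.position_ids", "pooler"])
    except ValueError as err:
        print(err)  # names the offending keys
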
src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py
@@ -73,8 +73,12 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
model = LukeModel(config=config).eval()
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
assert len(missing_keys) == 1 and missing_keys[0] == "embeddings.position_ids"
assert all(key.startswith("entity_predictions") or key.startswith("lm_head") for key in unexpected_keys)
if not (len(missing_keys) == 1 and missing_keys[0] == "embeddings.position_ids"):
raise ValueError(f"Missing keys {', '.join(missing_keys)}. Expected only missing embeddings.position_ids")
if not all(key.startswith("entity_predictions") or key.startswith("lm_head") for key in unexpected_keys):
raise ValueError(
f"Unexpected keys {', '.join([key for key in unexpected_keys if not (key.startswith('entity_predictions') or key.startswith('lm_head'))])}"
)
# Check outputs
tokenizer = LukeTokenizer.from_pretrained(pytorch_dump_folder_path, task="entity_classification")
@@ -95,8 +99,12 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
expected_shape = torch.Size((1, 42, 768))
expected_slice = torch.tensor([[0.0037, 0.1368, -0.0091], [0.1099, 0.3329, -0.1095], [0.0765, 0.5335, 0.1179]])
assert outputs.last_hidden_state.shape == expected_shape
assert torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)
if outputs.last_hidden_state.shape != expected_shape:
raise ValueError(
f"outputs.last_hidden_state.shape is {outputs.last_hidden_state.shape}, expected shape is {expected_shape}"
)
if not torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
raise ValueError("outputs.last_hidden_state[0, :3, :3] does not match the expected slice")
# Verify entity hidden states
if model_size == "large":
@@ -106,8 +114,12 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
expected_shape = torch.Size((1, 1, 768))
expected_slice = torch.tensor([[0.1457, 0.1044, 0.0174]])
assert outputs.entity_last_hidden_state.shape == expected_shape
assert torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)
if outputs.entity_last_hidden_state.shape != expected_shape:
raise ValueError(
f"outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, expected shape is {expected_shape}"
)
if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
raise ValueError("outputs.entity_last_hidden_state[0, :3, :3] does not match the expected slice")
# Finally, save our PyTorch model and tokenizer
print("Saving PyTorch model to {}".format(pytorch_dump_folder_path))
src/transformers/models/marian/convert_marian_to_pytorch.py
@@ -112,7 +112,8 @@ def load_config_from_state_dict(opus_dict):
def find_model_file(dest_dir): # this one better
model_files = list(Path(dest_dir).glob("*.npz"))
assert len(model_files) == 1, model_files
if len(model_files) != 1:
raise ValueError(f"Expected exactly one model file, found: {model_files}")
model_file = model_files[0]
return model_file
@@ -218,9 +219,11 @@ def write_model_card(
hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
opus_name: str = convert_hf_name_to_opus_name(hf_model_name)
assert repo_root in ("OPUS-MT-train", "Tatoeba-Challenge")
if repo_root not in ("OPUS-MT-train", "Tatoeba-Challenge"):
raise ValueError(f"Repo root is {repo_root}. Expected either OPUS-MT-train or Tatoeba-Challenge")
opus_readme_path = Path(repo_root).joinpath("models", opus_name, "README.md")
assert opus_readme_path.exists(), f"Readme file {opus_readme_path} not found"
if not opus_readme_path.exists():
raise ValueError(f"Readme file {opus_readme_path} not found")
opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]
@@ -321,9 +324,8 @@ def fetch_test_set(test_set_url):
src = lmap(str.strip, lns[::4])
gold = lmap(str.strip, lns[1::4])
mar_model = lmap(str.strip, lns[2::4])
assert (
len(gold) == len(mar_model) == len(src)
), f"Gold, marian and source lengths {len(gold)}, {len(mar_model)}, {len(src)} mismatched"
if not (len(gold) == len(mar_model) == len(src)):
raise ValueError(f"Gold, marian and source lengths {len(gold)}, {len(mar_model)}, {len(src)} mismatched")
os.remove(fname)
return src, mar_model, gold
@@ -391,7 +393,8 @@ def add_special_tokens_to_vocab(model_dir: Path) -> None:
def check_equal(marian_cfg, k1, k2):
v1, v2 = marian_cfg[k1], marian_cfg[k2]
assert v1 == v2, f"hparams {k1},{k2} differ: {v1} != {v2}"
if v1 != v2:
raise ValueError(f"hparams {k1},{k2} differ: {v1} != {v2}")
def check_marian_cfg_assumptions(marian_cfg):
@@ -413,7 +416,8 @@ def check_marian_cfg_assumptions(marian_cfg):
}
for k, v in assumed_settings.items():
actual = marian_cfg[k]
assert actual == v, f"Unexpected config value for {k} expected {v} got {actual}"
if actual != v:
raise ValueError(f"Unexpected config value for {k} expected {v} got {actual}")
check_equal(marian_cfg, "transformer-ffn-activation", "transformer-aan-activation")
check_equal(marian_cfg, "transformer-ffn-depth", "transformer-aan-depth")
check_equal(marian_cfg, "transformer-dim-ffn", "transformer-dim-aan")
@@ -456,22 +460,24 @@ class OpusState:
npz_path = find_model_file(source_dir)
self.state_dict = np.load(npz_path)
cfg = load_config_from_state_dict(self.state_dict)
assert cfg["dim-vocabs"][0] == cfg["dim-vocabs"][1]
assert "Wpos" not in self.state_dict, "Wpos key in state dictionary"
if cfg["dim-vocabs"][0] != cfg["dim-vocabs"][1]:
raise ValueError
if "Wpos" in self.state_dict:
raise ValueError("Wpos key in state dictionary")
self.state_dict = dict(self.state_dict)
self.wemb, self.final_bias = add_emb_entries(self.state_dict["Wemb"], self.state_dict[BIAS_KEY], 1)
self.pad_token_id = self.wemb.shape[0] - 1
cfg["vocab_size"] = self.pad_token_id + 1
# self.state_dict['Wemb'].sha
self.state_keys = list(self.state_dict.keys())
assert "Wtype" not in self.state_dict, "Wtype key in state dictionary"
if "Wtype" in self.state_dict:
raise ValueError("Wtype key in state dictionary")
self._check_layer_entries()
self.source_dir = source_dir
self.cfg = cfg
hidden_size, intermediate_shape = self.state_dict["encoder_l1_ffn_W1"].shape
assert (
hidden_size == cfg["dim-emb"] == 512
), f"Hidden size {hidden_size} and configured size {cfg['dim_emb']} mismatched or not 512"
if hidden_size != 512 or cfg["dim-emb"] != 512:
raise ValueError(f"Hidden size {hidden_size} and configured size {cfg['dim-emb']} mismatched or not 512")
# Process decoder.yml
decoder_yml = cast_marian_config(load_yaml(source_dir / "decoder.yml"))
@@ -532,10 +538,12 @@ class OpusState:
def load_marian_model(self) -> MarianMTModel:
state_dict, cfg = self.state_dict, self.hf_config
assert cfg.static_position_embeddings, "config.static_position_embeddings should be True"
if not cfg.static_position_embeddings:
raise ValueError("config.static_position_embeddings should be True")
model = MarianMTModel(cfg)
assert "hidden_size" not in cfg.to_dict()
if "hidden_size" in cfg.to_dict():
raise ValueError("hidden_size is in config")
load_layers_(
model.model.encoder.layers,
state_dict,
@@ -558,13 +566,14 @@ class OpusState:
model.model.decoder.embed_positions.weight = wpos_tensor
if cfg.normalize_embedding:
assert "encoder_emb_ln_scale_pre" in state_dict
if not ("encoder_emb_ln_scale_pre" in state_dict):
raise ValueError("encoder_emb_ln_scale_pre is not in state dictionary")
raise NotImplementedError("Need to convert layernorm_embedding")
assert not self.extra_keys, f"Failed to convert {self.extra_keys}"
assert (
model.model.shared.padding_idx == self.pad_token_id
), f"Padding tokens {model.model.shared.padding_idx} and {self.pad_token_id} mismatched"
if self.extra_keys:
raise ValueError(f"Failed to convert {self.extra_keys}")
if model.model.shared.padding_idx != self.pad_token_id:
raise ValueError(f"Padding tokens {model.model.shared.padding_idx} and {self.pad_token_id} mismatched")
return model
@@ -588,9 +597,10 @@ def convert(source_dir: Path, dest_dir):
tokenizer.save_pretrained(dest_dir)
opus_state = OpusState(source_dir)
assert opus_state.cfg["vocab_size"] == len(
tokenizer.encoder
), f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched"
if opus_state.cfg["vocab_size"] != len(tokenizer.encoder):
raise ValueError(
f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched"
)
# save_json(opus_state.cfg, dest_dir / "marian_original_config.json")
# ^^ Uncomment to save human readable marian config for debugging
@@ -628,6 +638,7 @@ if __name__ == "__main__":
args = parser.parse_args()
source_dir = Path(args.src)
assert source_dir.exists(), f"Source directory {source_dir} not found"
if not source_dir.exists():
raise ValueError(f"Source directory {source_dir} not found")
dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
convert(source_dir, dest_dir)
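
A side benefit of raising ValueError instead of asserting is that the failure paths can be exercised in tests without -O caveats. A hedged sketch, assuming pytest is available and a transformers checkout where find_model_file (from convert_marian_to_pytorch.py) is importable; the test name and temporary-directory setup are illustrative:

    import tempfile
    from pathlib import Path

    import pytest

    from transformers.models.marian.convert_marian_to_pytorch import find_model_file

    def test_find_model_file_requires_single_npz():
        with tempfile.TemporaryDirectory() as dest_dir:
            # Two .npz files should now trigger the explicit ValueError.
            (Path(dest_dir) / "a.npz").touch()
            (Path(dest_dir) / "b.npz").touch()
            with pytest.raises(ValueError):
                find_model_file(dest_dir)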