Unverified Commit 5a0dac53 authored by Teven's avatar Teven Committed by GitHub
Browse files

Empty assert hunt (#6056)



* Fixed empty asserts

* black-reformatted stragglers in templates

* More code quality checks

* Update src/transformers/convert_marian_to_pytorch.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/convert_marian_to_pytorch.py
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* removed unused line as per @sshleifer
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 16c22401
...@@ -81,7 +81,6 @@ class TrainCommand(BaseTransformersCLICommand): ...@@ -81,7 +81,6 @@ class TrainCommand(BaseTransformersCLICommand):
self.framework = "tf" if is_tf_available() else "torch" self.framework = "tf" if is_tf_available() else "torch"
os.makedirs(args.output, exist_ok=True) os.makedirs(args.output, exist_ok=True)
assert os.path.isdir(args.output)
self.output = args.output self.output = args.output
self.column_label = args.column_label self.column_label = args.column_label
......
...@@ -166,7 +166,7 @@ def write_model_card( ...@@ -166,7 +166,7 @@ def write_model_card(
extra_markdown = f"### {hf_model_name}\n\n* source languages: {s}\n* target languages: {t}\n* OPUS readme: [{opus_name}]({readme_url})\n" extra_markdown = f"### {hf_model_name}\n\n* source languages: {s}\n* target languages: {t}\n* OPUS readme: [{opus_name}]({readme_url})\n"
# combine with opus markdown # combine with opus markdown
opus_readme_path = Path(f"{repo_path}{opus_name}/README.md") opus_readme_path = Path(f"{repo_path}{opus_name}/README.md")
assert opus_readme_path.exists(), opus_readme_path assert opus_readme_path.exists(), f"Readme file {opus_readme_path} not found"
content = opus_readme_path.open().read() content = opus_readme_path.open().read()
content = content.split("\n# ")[-1] # Get the lowest level 1 header in the README -- the most recent model. content = content.split("\n# ")[-1] # Get the lowest level 1 header in the README -- the most recent model.
content = "*".join(content.split("*")[1:]) content = "*".join(content.split("*")[1:])
...@@ -231,7 +231,9 @@ def fetch_test_set(test_set_url): ...@@ -231,7 +231,9 @@ def fetch_test_set(test_set_url):
src = lmap(str.strip, lns[::4]) src = lmap(str.strip, lns[::4])
gold = lmap(str.strip, lns[1::4]) gold = lmap(str.strip, lns[1::4])
mar_model = lmap(str.strip, lns[2::4]) mar_model = lmap(str.strip, lns[2::4])
assert len(gold) == len(mar_model) == len(src) assert (
len(gold) == len(mar_model) == len(src)
), f"Gold, marian and source lengths {len(gold)}, {len(mar_model)}, {len(src)} mismatched"
os.remove(fname) os.remove(fname)
return src, mar_model, gold return src, mar_model, gold
...@@ -374,20 +376,21 @@ class OpusState: ...@@ -374,20 +376,21 @@ class OpusState:
self.state_dict = np.load(npz_path) self.state_dict = np.load(npz_path)
cfg = load_config_from_state_dict(self.state_dict) cfg = load_config_from_state_dict(self.state_dict)
assert cfg["dim-vocabs"][0] == cfg["dim-vocabs"][1] assert cfg["dim-vocabs"][0] == cfg["dim-vocabs"][1]
assert "Wpos" not in self.state_dict assert "Wpos" not in self.state_dict, "Wpos key in state dictionary"
self.state_dict = dict(self.state_dict) self.state_dict = dict(self.state_dict)
self.wemb, self.final_bias = add_emb_entries(self.state_dict["Wemb"], self.state_dict[BIAS_KEY], 1) self.wemb, self.final_bias = add_emb_entries(self.state_dict["Wemb"], self.state_dict[BIAS_KEY], 1)
self.pad_token_id = self.wemb.shape[0] - 1 self.pad_token_id = self.wemb.shape[0] - 1
cfg["vocab_size"] = self.pad_token_id + 1 cfg["vocab_size"] = self.pad_token_id + 1
# self.state_dict['Wemb'].sha # self.state_dict['Wemb'].sha
self.state_keys = list(self.state_dict.keys()) self.state_keys = list(self.state_dict.keys())
if "Wtype" in self.state_dict: assert "Wtype" not in self.state_dict, "Wtype key in state dictionary"
raise ValueError("found Wtype key")
self._check_layer_entries() self._check_layer_entries()
self.source_dir = source_dir self.source_dir = source_dir
self.cfg = cfg self.cfg = cfg
hidden_size, intermediate_shape = self.state_dict["encoder_l1_ffn_W1"].shape hidden_size, intermediate_shape = self.state_dict["encoder_l1_ffn_W1"].shape
assert hidden_size == cfg["dim-emb"] == 512 assert (
hidden_size == cfg["dim-emb"] == 512
), f"Hidden size {hidden_size} and configured size {cfg['dim_emb']} mismatched or not 512"
# Process decoder.yml # Process decoder.yml
decoder_yml = cast_marian_config(load_yaml(source_dir / "decoder.yml")) decoder_yml = cast_marian_config(load_yaml(source_dir / "decoder.yml"))
...@@ -448,7 +451,7 @@ class OpusState: ...@@ -448,7 +451,7 @@ class OpusState:
def load_marian_model(self) -> MarianMTModel: def load_marian_model(self) -> MarianMTModel:
state_dict, cfg = self.state_dict, self.hf_config state_dict, cfg = self.state_dict, self.hf_config
assert cfg.static_position_embeddings assert cfg.static_position_embeddings, "config.static_position_embeddings should be True"
model = MarianMTModel(cfg) model = MarianMTModel(cfg)
assert "hidden_size" not in cfg.to_dict() assert "hidden_size" not in cfg.to_dict()
...@@ -476,7 +479,9 @@ class OpusState: ...@@ -476,7 +479,9 @@ class OpusState:
raise NotImplementedError("Need to convert layernorm_embedding") raise NotImplementedError("Need to convert layernorm_embedding")
assert not self.extra_keys, f"Failed to convert {self.extra_keys}" assert not self.extra_keys, f"Failed to convert {self.extra_keys}"
assert model.model.shared.padding_idx == self.pad_token_id assert (
model.model.shared.padding_idx == self.pad_token_id
), f"Padding tokens {model.model.shared.padding_idx} and {self.pad_token_id} mismatched"
return model return model
...@@ -500,7 +505,9 @@ def convert(source_dir: Path, dest_dir): ...@@ -500,7 +505,9 @@ def convert(source_dir: Path, dest_dir):
save_tokenizer(tokenizer, dest_dir) save_tokenizer(tokenizer, dest_dir)
opus_state = OpusState(source_dir) opus_state = OpusState(source_dir)
assert opus_state.cfg["vocab_size"] == len(tokenizer.encoder) assert opus_state.cfg["vocab_size"] == len(
tokenizer.encoder
), f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched"
# save_json(opus_state.cfg, dest_dir / "marian_original_config.json") # save_json(opus_state.cfg, dest_dir / "marian_original_config.json")
# ^^ Save human readable marian config for debugging # ^^ Save human readable marian config for debugging
...@@ -517,7 +524,7 @@ if __name__ == "__main__": ...@@ -517,7 +524,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
source_dir = Path(args.src) source_dir = Path(args.src)
assert source_dir.exists() assert source_dir.exists(), f"Source directory {source_dir} not found"
dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
convert(source_dir, dest_dir) convert(source_dir, dest_dir)
......
...@@ -22,7 +22,7 @@ class TextDataset(Dataset): ...@@ -22,7 +22,7 @@ class TextDataset(Dataset):
def __init__( def __init__(
self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False, self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False,
): ):
assert os.path.isfile(file_path) assert os.path.isfile(file_path), f"Input file path {file_path} not found"
block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False) block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
...@@ -82,7 +82,7 @@ class LineByLineTextDataset(Dataset): ...@@ -82,7 +82,7 @@ class LineByLineTextDataset(Dataset):
""" """
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int): def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
assert os.path.isfile(file_path) assert os.path.isfile(file_path), f"Input file path {file_path} not found"
# Here, we do not cache the features, operating under the assumption # Here, we do not cache the features, operating under the assumption
# that we will soon use fast multithreaded tokenizers from the # that we will soon use fast multithreaded tokenizers from the
# `tokenizers` repo everywhere =) # `tokenizers` repo everywhere =)
......
...@@ -51,7 +51,9 @@ if _has_sklearn: ...@@ -51,7 +51,9 @@ if _has_sklearn:
} }
def glue_compute_metrics(task_name, preds, labels): def glue_compute_metrics(task_name, preds, labels):
assert len(preds) == len(labels) assert len(preds) == len(
labels
), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
if task_name == "cola": if task_name == "cola":
return {"mcc": matthews_corrcoef(labels, preds)} return {"mcc": matthews_corrcoef(labels, preds)}
elif task_name == "sst-2": elif task_name == "sst-2":
...@@ -78,7 +80,9 @@ if _has_sklearn: ...@@ -78,7 +80,9 @@ if _has_sklearn:
raise KeyError(task_name) raise KeyError(task_name)
def xnli_compute_metrics(task_name, preds, labels): def xnli_compute_metrics(task_name, preds, labels):
assert len(preds) == len(labels) assert len(preds) == len(
labels
), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
if task_name == "xnli": if task_name == "xnli":
return {"acc": simple_accuracy(preds, labels)} return {"acc": simple_accuracy(preds, labels)}
else: else:
......
...@@ -523,7 +523,7 @@ def compute_predictions_logits( ...@@ -523,7 +523,7 @@ def compute_predictions_logits(
if not nbest: if not nbest:
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1 assert len(nbest) >= 1, "No valid predictions"
total_scores = [] total_scores = []
best_non_null_entry = None best_non_null_entry = None
...@@ -544,7 +544,7 @@ def compute_predictions_logits( ...@@ -544,7 +544,7 @@ def compute_predictions_logits(
output["end_logit"] = entry.end_logit output["end_logit"] = entry.end_logit
nbest_json.append(output) nbest_json.append(output)
assert len(nbest_json) >= 1 assert len(nbest_json) >= 1, "No valid predictions"
if not version_2_with_negative: if not version_2_with_negative:
all_predictions[example.qas_id] = nbest_json[0]["text"] all_predictions[example.qas_id] = nbest_json[0]["text"]
...@@ -739,8 +739,8 @@ def compute_predictions_log_probs( ...@@ -739,8 +739,8 @@ def compute_predictions_log_probs(
output["end_log_prob"] = entry.end_log_prob output["end_log_prob"] = entry.end_log_prob
nbest_json.append(output) nbest_json.append(output)
assert len(nbest_json) >= 1 assert len(nbest_json) >= 1, "No valid predictions"
assert best_non_null_entry is not None assert best_non_null_entry is not None, "No valid predictions"
score_diff = score_null score_diff = score_null
scores_diff_json[example.qas_id] = score_diff scores_diff_json[example.qas_id] = score_diff
......
...@@ -194,8 +194,12 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -194,8 +194,12 @@ class SingleSentenceClassificationProcessor(DataProcessor):
def add_examples( def add_examples(
self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
): ):
assert labels is None or len(texts_or_text_and_labels) == len(labels) assert labels is None or len(texts_or_text_and_labels) == len(
assert ids is None or len(texts_or_text_and_labels) == len(ids) labels
), f"Text and labels have mismatched lengths {len(texts_or_text_and_labels)} and {len(labels)}"
assert ids is None or len(texts_or_text_and_labels) == len(
ids
), f"Text and ids have mismatched lengths {len(texts_or_text_and_labels)} and {len(ids)}"
if ids is None: if ids is None:
ids = [None] * len(texts_or_text_and_labels) ids = [None] * len(texts_or_text_and_labels)
if labels is None: if labels is None:
......
...@@ -45,7 +45,9 @@ class XnliProcessor(DataProcessor): ...@@ -45,7 +45,9 @@ class XnliProcessor(DataProcessor):
text_a = line[0] text_a = line[0]
text_b = line[1] text_b = line[1]
label = "contradiction" if line[2] == "contradictory" else line[2] label = "contradiction" if line[2] == "contradictory" else line[2]
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str) assert isinstance(text_a, str), f"Training input {text_a} is not a string"
assert isinstance(text_b, str), f"Training input {text_b} is not a string"
assert isinstance(label, str), f"Training label {label} is not a string"
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -63,7 +65,9 @@ class XnliProcessor(DataProcessor): ...@@ -63,7 +65,9 @@ class XnliProcessor(DataProcessor):
text_a = line[6] text_a = line[6]
text_b = line[7] text_b = line[7]
label = line[1] label = line[1]
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str) assert isinstance(text_a, str), f"Training input {text_a} is not a string"
assert isinstance(text_b, str), f"Training input {text_b} is not a string"
assert isinstance(label, str), f"Training label {label} is not a string"
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
......
...@@ -179,7 +179,9 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path): ...@@ -179,7 +179,9 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
elif m_name == "kernel": elif m_name == "kernel":
array = np.transpose(array) array = np.transpose(array)
try: try:
assert pointer.shape == array.shape assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
......
...@@ -146,7 +146,9 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path): ...@@ -146,7 +146,9 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
elif m_name == "kernel": elif m_name == "kernel":
array = np.transpose(array) array = np.transpose(array)
try: try:
assert pointer.shape == array.shape assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
......
...@@ -114,7 +114,9 @@ def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_ ...@@ -114,7 +114,9 @@ def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_
elif m_name == "kernel": elif m_name == "kernel":
array = np.transpose(array) array = np.transpose(array)
try: try:
assert pointer.shape == array.shape, original_name assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
......
...@@ -106,7 +106,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): ...@@ -106,7 +106,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
num = int(scope_names[1]) num = int(scope_names[1])
pointer = pointer[num] pointer = pointer[num]
try: try:
assert pointer.shape == array.shape assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
......
...@@ -130,7 +130,9 @@ def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path): ...@@ -130,7 +130,9 @@ def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
elif m_name == "kernel": elif m_name == "kernel":
array = np.transpose(array) array = np.transpose(array)
try: try:
assert pointer.shape == array.shape assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
......
...@@ -121,12 +121,16 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): ...@@ -121,12 +121,16 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
num = int(scope_names[1]) num = int(scope_names[1])
pointer = pointer[num] pointer = pointer[num]
try: try:
assert pointer.shape == array.shape assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
try: try:
assert pointer.shape == array.shape assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
......
...@@ -131,7 +131,9 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): ...@@ -131,7 +131,9 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
logger.info("Transposing numpy weight of shape {} for {}".format(array.shape, name)) logger.info("Transposing numpy weight of shape {} for {}".format(array.shape, name))
array = np.transpose(array) array = np.transpose(array)
try: try:
assert pointer.shape == array.shape assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
......
...@@ -170,7 +170,9 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer): ...@@ -170,7 +170,9 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
) )
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
assert config.hidden_size % config.num_attention_heads == 0 assert (
config.hidden_size % config.num_attention_heads == 0
), f"Hidden size {config.hidden_size} not dividable by number of heads {config.num_attention_heads}"
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
......
...@@ -195,7 +195,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer): ...@@ -195,7 +195,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
self.dim = config.dim self.dim = config.dim
self.dropout = tf.keras.layers.Dropout(config.attention_dropout) self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
assert self.dim % self.n_heads == 0 assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}"
self.q_lin = tf.keras.layers.Dense( self.q_lin = tf.keras.layers.Dense(
config.dim, kernel_initializer=get_initializer(config.initializer_range), name="q_lin" config.dim, kernel_initializer=get_initializer(config.initializer_range), name="q_lin"
...@@ -311,7 +311,9 @@ class TFTransformerBlock(tf.keras.layers.Layer): ...@@ -311,7 +311,9 @@ class TFTransformerBlock(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
self.activation = config.activation self.activation = config.activation
assert config.dim % config.n_heads == 0 assert (
config.dim % config.n_heads == 0
), f"Hidden size {config.dim} not dividable by number of heads {config.n_heads}"
self.attention = TFMultiHeadSelfAttention(config, name="attention") self.attention = TFMultiHeadSelfAttention(config, name="attention")
self.sa_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="sa_layer_norm") self.sa_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="sa_layer_norm")
...@@ -395,11 +397,11 @@ class TFTransformer(tf.keras.layers.Layer): ...@@ -395,11 +397,11 @@ class TFTransformer(tf.keras.layers.Layer):
hidden_state = layer_outputs[-1] hidden_state = layer_outputs[-1]
if cast_bool_to_primitive(output_attentions) is True: if cast_bool_to_primitive(output_attentions) is True:
assert len(layer_outputs) == 2 assert len(layer_outputs) == 2, f"Incorrect number of outputs {len(layer_outputs)} instead of 2"
attentions = layer_outputs[0] attentions = layer_outputs[0]
all_attentions = all_attentions + (attentions,) all_attentions = all_attentions + (attentions,)
else: else:
assert len(layer_outputs) == 1 assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1"
# Add last layer # Add last layer
if cast_bool_to_primitive(output_hidden_states) is True: if cast_bool_to_primitive(output_hidden_states) is True:
...@@ -1024,7 +1026,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn ...@@ -1024,7 +1026,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
self.qa_outputs = tf.keras.layers.Dense( self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
) )
assert config.num_labels == 2 assert config.num_labels == 2, f"Incorrect number of labels {config.num_labels} instead of 2"
self.dropout = tf.keras.layers.Dropout(config.qa_dropout) self.dropout = tf.keras.layers.Dropout(config.qa_dropout)
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
......
...@@ -193,7 +193,9 @@ class TFFlaubertMainLayer(TFXLMMainLayer): ...@@ -193,7 +193,9 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
# check inputs # check inputs
# assert shape_list(lengths)[0] == bs # assert shape_list(lengths)[0] == bs
tf.debugging.assert_equal(shape_list(lengths)[0], bs) tf.debugging.assert_equal(
shape_list(lengths)[0], bs
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
# assert lengths.max().item() <= slen # assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None) # assert (src_enc is None) == (src_len is None)
...@@ -211,13 +213,17 @@ class TFFlaubertMainLayer(TFXLMMainLayer): ...@@ -211,13 +213,17 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
position_ids = tf.expand_dims(tf.range(slen), axis=0) position_ids = tf.expand_dims(tf.range(slen), axis=0)
else: else:
# assert shape_list(position_ids) == [bs, slen] # (slen, bs) # assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(shape_list(position_ids), [bs, slen]) tf.debugging.assert_equal(
shape_list(position_ids), [bs, slen]
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1) # position_ids = position_ids.transpose(0, 1)
# langs # langs
if langs is not None: if langs is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs) # assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(shape_list(langs), [bs, slen]) tf.debugging.assert_equal(
shape_list(langs), [bs, slen]
), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
# langs = langs.transpose(0, 1) # langs = langs.transpose(0, 1)
# Prepare head mask if needed # Prepare head mask if needed
......
...@@ -77,7 +77,9 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -77,7 +77,9 @@ class TFAttention(tf.keras.layers.Layer):
n_state = nx # in Attention: n_state=768 (nx=n_embd) n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem] # [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % config.n_head == 0 assert (
n_state % config.n_head == 0
), f"Hidden dimension {n_state} not dividable by number of heads {config.n_head}"
self.n_ctx = n_ctx self.n_ctx = n_ctx
self.n_head = config.n_head self.n_head = config.n_head
self.split_size = n_state self.split_size = n_state
......
...@@ -493,8 +493,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -493,8 +493,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len) bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
if bsz is not None: if bsz is not None:
# With bi_data, the batch size should be divisible by 2. assert bsz % 2 == 0, f"With bi_data, the batch size {bsz} should be divisible by 2"
assert bsz % 2 == 0
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2) fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2) bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
else: else:
......
...@@ -155,7 +155,9 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path): ...@@ -155,7 +155,9 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
p_i.data = torch.from_numpy(arr_i) p_i.data = torch.from_numpy(arr_i)
else: else:
try: try:
assert pointer.shape == array.shape assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment