Unverified Commit 5a0dac53 authored by Teven's avatar Teven Committed by GitHub
Browse files

Empty assert hunt (#6056)



* Fixed empty asserts

* black-reformatted stragglers in templates

* More code quality checks

* Update src/transformers/convert_marian_to_pytorch.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/convert_marian_to_pytorch.py
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* removed unused line as per @sshleifer
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 16c22401
...@@ -169,11 +169,15 @@ def load_tf_weights_in_xlnet(model, config, tf_path): ...@@ -169,11 +169,15 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
array = np.transpose(array) array = np.transpose(array)
if isinstance(pointer, list): if isinstance(pointer, list):
# Here we will split the TF weights # Here we will split the TF weights
assert len(pointer) == array.shape[0] assert (
len(pointer) == array.shape[0]
), f"Pointer length {len(pointer)} and array length {array.shape[0]} mismatched"
for i, p_i in enumerate(pointer): for i, p_i in enumerate(pointer):
arr_i = array[i, ...] arr_i = array[i, ...]
try: try:
assert p_i.shape == arr_i.shape assert (
p_i.shape == arr_i.shape
), f"Pointer shape {p_i.shape} and array shape {arr_i.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (p_i.shape, arr_i.shape) e.args += (p_i.shape, arr_i.shape)
raise raise
...@@ -181,7 +185,9 @@ def load_tf_weights_in_xlnet(model, config, tf_path): ...@@ -181,7 +185,9 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
p_i.data = torch.from_numpy(arr_i) p_i.data = torch.from_numpy(arr_i)
else: else:
try: try:
assert pointer.shape == array.shape assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
......
...@@ -147,7 +147,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -147,7 +147,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
def count_file(self, path, verbose=False, add_eos=False): def count_file(self, path, verbose=False, add_eos=False):
if verbose: if verbose:
logger.info("counting file {} ...".format(path)) logger.info("counting file {} ...".format(path))
assert os.path.exists(path) assert os.path.exists(path), f"Input file {path} not found"
sents = [] sents = []
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f:
...@@ -233,7 +233,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -233,7 +233,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False): def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
if verbose: if verbose:
logger.info("encoding file {} ...".format(path)) logger.info("encoding file {} ...".format(path))
assert os.path.exists(path) assert os.path.exists(path), f"Input file {path} not found"
encoded = [] encoded = []
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f:
for idx, line in enumerate(f): for idx, line in enumerate(f):
......
...@@ -683,7 +683,8 @@ class SpecialTokensMixin: ...@@ -683,7 +683,8 @@ class SpecialTokensMixin:
for key, value in kwargs.items(): for key, value in kwargs.items():
if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == "additional_special_tokens": if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value) assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple"
assert all(isinstance(t, str) for t in value), "One of the tokens is not a string"
setattr(self, key, value) setattr(self, key, value)
elif isinstance(value, (str, AddedToken)): elif isinstance(value, (str, AddedToken)):
setattr(self, key, value) setattr(self, key, value)
...@@ -752,7 +753,7 @@ class SpecialTokensMixin: ...@@ -752,7 +753,7 @@ class SpecialTokensMixin:
added_tokens = 0 added_tokens = 0
for key, value in special_tokens_dict.items(): for key, value in special_tokens_dict.items():
assert key in self.SPECIAL_TOKENS_ATTRIBUTES assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token"
if self.verbose: if self.verbose:
logger.info("Assigning %s to the %s key of the tokenizer", value, key) logger.info("Assigning %s to the %s key of the tokenizer", value, key)
......
...@@ -124,11 +124,15 @@ class SequentialDistributedSampler(Sampler): ...@@ -124,11 +124,15 @@ class SequentialDistributedSampler(Sampler):
# add extra samples to make it evenly divisible # add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))] indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size assert (
len(indices) == self.total_size
), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
# subsample # subsample
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
assert len(indices) == self.num_samples assert (
len(indices) == self.num_samples
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
return iter(indices) return iter(indices)
...@@ -566,9 +570,11 @@ class Trainer: ...@@ -566,9 +570,11 @@ class Trainer:
# In all cases (even distributed/parallel), self.model is always a reference # In all cases (even distributed/parallel), self.model is always a reference
# to the model we want to save. # to the model we want to save.
if hasattr(model, "module"): if hasattr(model, "module"):
assert model.module is self.model assert (
model.module is self.model
), f"Module {model.module} should be a reference to self.model"
else: else:
assert model is self.model assert model is self.model, f"Model {model} should be a reference to self.model"
# Save model checkpoint # Save model checkpoint
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}") output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
......
...@@ -327,9 +327,15 @@ def convert_examples_to_features( ...@@ -327,9 +327,15 @@ def convert_examples_to_features(
segment_ids.append(pad_token_segment_id) segment_ids.append(pad_token_segment_id)
p_mask.append(1) p_mask.append(1)
assert len(input_ids) == max_seq_length assert (
assert len(input_mask) == max_seq_length len(input_ids) == max_seq_length
assert len(segment_ids) == max_seq_length ), f"Input ids and sequence have mismatched lengths {len(input_ids)} and {max_seq_length}"
assert (
len(input_mask) == max_seq_length
), f"Input mask and sequence have mismatched lengths {len(input_mask)} and {max_seq_length}"
assert (
len(segment_ids) == max_seq_length
), f"Segment ids and sequence have mismatched lengths {len(segment_ids)} and {max_seq_length}"
span_is_impossible = example.is_impossible span_is_impossible = example.is_impossible
start_position = None start_position = None
...@@ -626,7 +632,7 @@ def write_predictions( ...@@ -626,7 +632,7 @@ def write_predictions(
if not nbest: if not nbest:
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1 assert len(nbest) >= 1, "No valid predictions"
total_scores = [] total_scores = []
best_non_null_entry = None best_non_null_entry = None
...@@ -647,7 +653,7 @@ def write_predictions( ...@@ -647,7 +653,7 @@ def write_predictions(
output["end_logit"] = entry.end_logit output["end_logit"] = entry.end_logit
nbest_json.append(output) nbest_json.append(output)
assert len(nbest_json) >= 1 assert len(nbest_json) >= 1, "No valid predictions"
if not version_2_with_negative: if not version_2_with_negative:
all_predictions[example.qas_id] = nbest_json[0]["text"] all_predictions[example.qas_id] = nbest_json[0]["text"]
...@@ -843,8 +849,8 @@ def write_predictions_extended( ...@@ -843,8 +849,8 @@ def write_predictions_extended(
output["end_log_prob"] = entry.end_log_prob output["end_log_prob"] = entry.end_log_prob
nbest_json.append(output) nbest_json.append(output)
assert len(nbest_json) >= 1 assert len(nbest_json) >= 1, "No valid predictions"
assert best_non_null_entry is not None assert best_non_null_entry is not None, "No valid predictions"
score_diff = score_null score_diff = score_null
scores_diff_json[example.qas_id] = score_diff scores_diff_json[example.qas_id] = score_diff
......
...@@ -121,7 +121,9 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path): ...@@ -121,7 +121,9 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
elif m_name == "kernel": elif m_name == "kernel":
array = np.transpose(array) array = np.transpose(array)
try: try:
assert pointer.shape == array.shape assert (
pointer.shape == array.shape
), f"Pointer and array have mismatched shapes {pointer.shape} and {array.shape}"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment