Unverified commit 239ace15, authored by Xiaoli Wang, committed by GitHub

Fix TypeError: Object of type int64 is not JSON serializable (#24340)

* Fix TypeError: Object of type int64 is not JSON serializable

* Convert numpy.float64 and numpy.int64 to float and int for json serialization

* Black reformatted examples/pytorch/token-classification/run_ner_no_trainer.py

* make style
parent ac19871c
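
Before the diff itself, a standalone reproduction may help: `json.dump` rejects NumPy scalar types, which is exactly the `TypeError` in the commit title. The sketch below is illustrative (the dictionary contents are made up); it mirrors the conversion this commit adds to `run_ner_no_trainer.py`.

import json

import numpy as np

# Metrics computed with NumPy come back as np.float64 / np.int64; json.dump()
# only understands built-in Python types and raises
# "TypeError: Object of type int64 is not JSON serializable" on these values.
all_results = {"eval_accuracy": np.float64(0.91), "eval_samples": np.int64(1024)}

# Cast NumPy scalars back to plain float / int before serializing,
# the same strategy the commit applies in the first hunk below.
for key, value in all_results.items():
    if isinstance(value, np.float64):
        all_results[key] = float(value)
    elif isinstance(value, np.int64):
        all_results[key] = int(value)

with open("all_results.json", "w") as f:
    json.dump(all_results, f)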
@@ -28,6 +28,7 @@ from pathlib import Path
 import datasets
 import evaluate
+import numpy as np
 import torch
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -777,6 +778,12 @@ def main():
         if args.with_tracking:
             all_results.update({"train_loss": total_loss.item() / len(train_dataloader)})
         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            # Convert all float64 & int64 type numbers to float & int for json serialization
+            for key, value in all_results.items():
+                if isinstance(value, np.float64):
+                    all_results[key] = float(value)
+                elif isinstance(value, np.int64):
+                    all_results[key] = int(value)
             json.dump(all_results, f)
...
@@ -60,7 +60,7 @@ class EndOfFunctionCriteria(StoppingCriteria):
         decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
         done = []
         for decoded_generation in decoded_generations:
-            done.append(any([stop_string in decoded_generation for stop_string in self.eof_strings]))
+            done.append(any(stop_string in decoded_generation for stop_string in self.eof_strings))
         return all(done)
...
@@ -17,7 +17,7 @@ class FSNERTokenizerUtils(object):
             `transformers.tokenization_utils_base.BatchEncoding` dict with additional keys and values for start_token_id, end_token_id and sizes of example lists for each entity type
         """
-        if isinstance(x, list) and all([isinstance(_x, list) for _x in x]):
+        if isinstance(x, list) and all(isinstance(_x, list) for _x in x):
             d = None
             for l in x:
                 t = self.tokenizer(
@@ -37,7 +37,7 @@ class FSNERTokenizerUtils(object):
             d["start_token_id"] = torch.tensor(self.tokenizer.convert_tokens_to_ids("[E]"))
             d["end_token_id"] = torch.tensor(self.tokenizer.convert_tokens_to_ids("[/E]"))
-        elif isinstance(x, list) and all([isinstance(_x, str) for _x in x]):
+        elif isinstance(x, list) and all(isinstance(_x, str) for _x in x):
             d = self.tokenizer(
                 x,
                 padding="max_length",
...
@@ -50,7 +50,7 @@ def _get_single_answer(example):
         answer["remove_it"] = False
     cols = ["start_token", "end_token", "start_byte", "end_byte", "text"]
-    if not all([isinstance(answer[k], list) for k in cols]):
+    if not all(isinstance(answer[k], list) for k in cols):
         raise ValueError("Issue in ID", example["id"])
     return answer
...
@@ -610,7 +610,7 @@ def main():
             predicted_sequence = [label_list[0]] * len(true_tags)
             for _, span, label in sorted(predictions, key=lambda o: o[0], reverse=True):
-                if all([o == label_list[0] for o in predicted_sequence[span[0] : span[1]]]):
+                if all(o == label_list[0] for o in predicted_sequence[span[0] : span[1]]):
                     predicted_sequence[span[0]] = label
                     if span[1] - span[0] > 1:
                         predicted_sequence[span[0] + 1 : span[1]] = [label] * (span[1] - span[0] - 1)
...
@@ -554,8 +554,8 @@ class Matcher(object):
         assert thresholds[0] > 0
         thresholds.insert(0, -float("inf"))
         thresholds.append(float("inf"))
-        assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])])
-        assert all([label_i in [-1, 0, 1] for label_i in labels])
+        assert all(low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:]))
+        assert all(label_i in [-1, 0, 1] for label_i in labels)
         assert len(labels) == len(thresholds) - 1
         self.thresholds = thresholds
         self.labels = labels
...
@@ -554,8 +554,8 @@ class Matcher(object):
         assert thresholds[0] > 0
         thresholds.insert(0, -float("inf"))
         thresholds.append(float("inf"))
-        assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])])
-        assert all([label_i in [-1, 0, 1] for label_i in labels])
+        assert all(low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:]))
+        assert all(label_i in [-1, 0, 1] for label_i in labels)
         assert len(labels) == len(thresholds) - 1
         self.thresholds = thresholds
         self.labels = labels
...
@@ -110,7 +110,7 @@ class MinLengthLogitsProcessor(LogitsProcessor):
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
+        if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
             logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
         self.min_length = min_length
@@ -147,7 +147,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
+        if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
             logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
         self.prompt_length_to_skip = prompt_length_to_skip
@@ -731,7 +731,7 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         bad_words_ids = list(
-            filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
+            filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
         )
         # Forbidding a sequence is equivalent to setting its bias to -inf
...
@@ -318,7 +318,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
         self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
         # 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons
         bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
-        if any([word_len == 0 for word_len in bad_word_seqs_len]):
+        if any(word_len == 0 for word_len in bad_word_seqs_len):
             raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
         self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
         # 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned
...
@@ -1638,7 +1638,7 @@ class TFGenerationMixin:
         # TODO (Joao): fix cache format or find programatic way to detect cache index
         # GPT2 and other models has a slightly different cache structure, with a different batch axis
         model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
-        cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
+        cache_batch_axis = 1 if any(model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")) else 0
         # some models, like XLNet, need more than the last token in the presence of past_key_values
         needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
@@ -1922,7 +1922,7 @@ class TFGenerationMixin:
         # TODO (Joao): fix cache format or find programatic way to detect cache index
         # GPT2 and other models has a slightly different cache structure, with a different batch axis
         model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
-        cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
+        cache_batch_axis = 1 if any(model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")) else 0
         # some models, like XLNet, need more than the last token in the presence of past_key_values
         needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
@@ -2265,7 +2265,7 @@ class TFGenerationMixin:
         # TODO (Joao): fix cache format or find programatic way to detect cache index
         # GPT2 and other models has a slightly different cache structure, with a different batch axis
         model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
-        cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
+        cache_batch_axis = 1 if any(model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")) else 0
         # some models, like XLNet, need more than the last token in the presence of past_key_values
         needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
@@ -2779,7 +2779,7 @@ class TFGenerationMixin:
         # TODO (Joao): fix cache format or find programatic way to detect cache index
         # GPT2 and other models has a slightly different cache structure, with a different batch axis
         model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
-        cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
+        cache_batch_axis = 1 if any(model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")) else 0
         # 2. init `attentions`, `hidden_states`, and `scores` tuples
         scores = [] if (return_dict_in_generate and output_scores) else None
...
@@ -144,7 +144,7 @@ class KerasMetricCallback(Callback):
     @staticmethod
     def _concatenate_batches(batches, padding_index=-100):
         # If all batches are unidimensional or same length, do a simple concatenation
-        if batches[0].ndim == 1 or all([batch.shape[1] == batches[0].shape[1] for batch in batches]):
+        if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches):
             return np.concatenate(batches, axis=0)
         # Welp, they're not the same length. Let's do some padding
...
@@ -78,7 +78,7 @@ def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name
     for var_name in state_dict:
         tf_name = to_tf_var_name(var_name)
         torch_tensor = state_dict[var_name].numpy()
-        if any([x in var_name for x in tensors_to_transpose]):
+        if any(x in var_name for x in tensors_to_transpose):
             torch_tensor = torch_tensor.T
         tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
         tf.keras.backend.set_value(tf_var, torch_tensor)
...
@@ -104,7 +104,7 @@ def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPeg
         new_k = rename_state_dict_key(k, patterns)
         if new_k not in state_dict:
             raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
-        if any([True if i in k else False for i in ["dense", "query", "key", "value"]]):
+        if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
             v = v.T
         mapping[new_k] = torch.from_numpy(v)
         assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
@@ -117,7 +117,7 @@ def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPeg
         new_k = rename_state_dict_key(k, patterns)
         if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
             raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
-        if any([True if i in k else False for i in ["dense", "query", "key", "value"]]):
+        if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
             v = v.T
         mapping[new_k] = torch.from_numpy(v)
         if k != "pegasus/embeddings/position_embeddings":
@@ -147,7 +147,7 @@ def get_tf_weights_as_numpy(path) -> Dict:
     tf_weights = {}
     ignore_name = ["global_step"]
     for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
-        skip_key = any([pat in name for pat in ignore_name])
+        skip_key = any(pat in name for pat in ignore_name)
         if skip_key:
             continue
         array = tf.train.load_variable(path, name)
...
@@ -2485,9 +2485,9 @@ class DetaMatcher(object):
         thresholds.insert(0, -float("inf"))
         thresholds.append(float("inf"))
         # Currently torchscript does not support all + generator
-        if not all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])]):
+        if not all(low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])):
             raise ValueError("Thresholds should be sorted.")
-        if not all([l in [-1, 0, 1] for l in labels]):
+        if not all(l in [-1, 0, 1] for l in labels):
             raise ValueError("All labels should be either -1, 0 or 1")
         if len(labels) != len(thresholds) - 1:
             raise ValueError("Number of labels should be equal to number of thresholds - 1")
...
@@ -379,11 +379,9 @@ class CustomDPRReaderTokenizerMixin:
             if length > max_answer_length:
                 raise ValueError(f"Span is too long: {length} > {max_answer_length}")
             if any(
-                [
-                    start_index <= prev_start_index <= prev_end_index <= end_index
-                    or prev_start_index <= start_index <= end_index <= prev_end_index
-                    for (prev_start_index, prev_end_index) in chosen_span_intervals
-                ]
+                start_index <= prev_start_index <= prev_end_index <= end_index
+                or prev_start_index <= start_index <= end_index <= prev_end_index
+                for (prev_start_index, prev_end_index) in chosen_span_intervals
             ):
                 continue
             chosen_span_intervals.append((start_index, end_index))
...
@@ -377,11 +377,9 @@ class CustomDPRReaderTokenizerMixin:
             length = end_index - start_index + 1
             assert length <= max_answer_length, f"Span is too long: {length} > {max_answer_length}"
             if any(
-                [
-                    start_index <= prev_start_index <= prev_end_index <= end_index
-                    or prev_start_index <= start_index <= end_index <= prev_end_index
-                    for (prev_start_index, prev_end_index) in chosen_span_intervals
-                ]
+                start_index <= prev_start_index <= prev_end_index <= end_index
+                or prev_start_index <= start_index <= end_index <= prev_end_index
+                for (prev_start_index, prev_end_index) in chosen_span_intervals
             ):
                 continue
             chosen_span_intervals.append((start_index, end_index))
...
@@ -90,7 +90,7 @@ def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
     tf_weights = {}
     ignore_name = ["Adafactor", "global_step"]
     for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
-        skip_key = any([pat in name for pat in ignore_name])
+        skip_key = any(pat in name for pat in ignore_name)
         if skip_key:
             continue
         array = tf.train.load_variable(path, name)
...
@@ -115,7 +115,7 @@ class SamProcessor(ProcessorMixin):
                 for point, original_size in zip(input_points, original_sizes)
             ]
             # check that all arrays have the same shape
-            if not all([point.shape == input_points[0].shape for point in input_points]):
+            if not all(point.shape == input_points[0].shape for point in input_points):
                 if input_labels is not None:
                     input_points, input_labels = self._pad_points_and_labels(input_points, input_labels)
...
@@ -647,7 +647,7 @@ class GenerationIntegrationTestsMixin:
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         unpadded_correct_condition = expectation == len(generated_tokens[0])
         padded_correct_condition = expectation < len(generated_tokens[0]) and all(
-            [token == model.config.pad_token_id for token in generated_tokens[0][expectation:]]
+            token == model.config.pad_token_id for token in generated_tokens[0][expectation:]
         )
         self.assertTrue(unpadded_correct_condition or padded_correct_condition)
@@ -655,7 +655,7 @@ class GenerationIntegrationTestsMixin:
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         unpadded_correct_condition = expectation == len(generated_tokens[0])
         padded_correct_condition = expectation < len(generated_tokens[0]) and all(
-            [token == model.config.pad_token_id for token in generated_tokens[0][expectation:]]
+            token == model.config.pad_token_id for token in generated_tokens[0][expectation:]
         )
         self.assertTrue(unpadded_correct_condition or padded_correct_condition)
...
@@ -521,7 +521,7 @@ class CodeGenModelLanguageGenerationTest(unittest.TestCase):
         self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
         self.assertTrue(
-            all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
+            all(output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs)))
         )  # token_type_ids should change output
     @is_flaky(max_attempts=3, description="measure of timing is somehow flaky.")
...
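
Apart from that serialization fix, the hunks above consistently replace a list comprehension inside `any()`/`all()` with a generator expression. Both forms return the same result, but the generator version avoids building an intermediate list and lets `any()`/`all()` short-circuit on the first decisive element. A minimal sketch of the difference, with made-up values:

# List comprehension: the full list is built before any() ever runs.
values = range(1_000_000)
with_list = any([v > 5 for v in values])  # evaluates all one million elements first

# Generator expression: any() pulls items lazily and stops at the first True.
with_generator = any(v > 5 for v in values)  # stops as soon as v == 6 is seen

assert with_list == with_generator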