Unverified Commit acc3bd9d authored by Sylvain Gugger, committed by GitHub

Enforce string-formatting with f-strings (#10980)



* First third

* Styling and fix mistake

* Quality

* All the rest

* Treat %s and %d

* typo

* Missing )

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent d0b3797a
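
For reference, a minimal sketch of the two conversion patterns this commit enforces (f-strings, introduced in Python 3.6 by PEP 498; the variables `name` and `count` below are hypothetical, for illustration only):

    name, count = "bert-base-uncased", 3

    # Before: str.format() and %-interpolation
    msg_format = "Loading model {}".format(name)
    msg_percent = "Loaded %d checkpoints for %s" % (count, name)

    # After: equivalent f-strings
    assert f"Loading model {name}" == msg_format
    assert f"Loaded {count} checkpoints for {name}" == msg_percent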
@@ -758,9 +758,7 @@ class Benchmark(ABC):
 
         if self.args.env_print:
             self.print_fn("\n" + 20 * "=" + ("ENVIRONMENT INFORMATION").center(40) + 20 * "=")
-            self.print_fn(
-                "\n".join(["- {}: {}".format(prop, val) for prop, val in self.environment_info.items()]) + "\n"
-            )
+            self.print_fn("\n".join([f"- {prop}: {val}" for prop, val in self.environment_info.items()]) + "\n")
 
         if self.args.save_to_csv:
             with open(self.args.env_info_csv_file, mode="w", newline="") as csv_file:
@@ -888,9 +886,7 @@ class Benchmark(ABC):
         self.print_fn("Saving results to csv.")
         with open(filename, mode="w") as csv_file:
 
-            assert len(self.args.model_names) > 0, "At least 1 model should be defined, but got {}".format(
-                self.model_names
-            )
+            assert len(self.args.model_names) > 0, f"At least 1 model should be defined, but got {self.model_names}"
 
             fieldnames = ["model", "batch_size", "sequence_length"]
             writer = csv.DictWriter(csv_file, fieldnames=fieldnames + ["result"])
...
@@ -76,7 +76,7 @@ class ConvertCommand(BaseTransformersCLICommand):
     ):
         self._logger = logging.get_logger("transformers-cli/converting")
 
-        self._logger.info("Loading model {}".format(model_type))
+        self._logger.info(f"Loading model {model_type}")
         self._model_type = model_type
         self._tf_checkpoint = tf_checkpoint
         self._pytorch_dump_output = pytorch_dump_output
...
@@ -56,8 +56,8 @@ class EnvironmentCommand(BaseTransformersCLICommand):
             "`transformers` version": version,
             "Platform": platform.platform(),
             "Python version": platform.python_version(),
-            "PyTorch version (GPU?)": "{} ({})".format(pt_version, pt_cuda_available),
-            "Tensorflow version (GPU?)": "{} ({})".format(tf_version, tf_cuda_available),
+            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
+            "Tensorflow version (GPU?)": f"{tf_version} ({tf_cuda_available})",
             "Using GPU in script?": "<fill in>",
             "Using distributed or parallel set-up in script?": "<fill in>",
         }
@@ -69,4 +69,4 @@ class EnvironmentCommand(BaseTransformersCLICommand):
 
     @staticmethod
     def format_dict(d):
-        return "\n".join(["- {}: {}".format(prop, val) for prop, val in d.items()]) + "\n"
+        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
@@ -31,8 +31,8 @@ def try_infer_format_from_ext(path: str):
             return ext
 
     raise Exception(
-        "Unable to determine file format from file extension {}. "
-        "Please provide the format through --format {}".format(path, PipelineDataFormat.SUPPORTED_FORMATS)
+        f"Unable to determine file format from file extension {path}. "
+        f"Please provide the format through --format {PipelineDataFormat.SUPPORTED_FORMATS}"
     )
@@ -105,6 +105,6 @@ class RunCommand(BaseTransformersCLICommand):
         # Saving data
         if self._nlp.binary_output:
             binary_path = self._reader.save_binary(outputs)
-            logger.warning("Current pipeline requires output to be in binary format, saving at {}".format(binary_path))
+            logger.warning(f"Current pipeline requires output to be in binary format, saving at {binary_path}")
         else:
             self._reader.save(outputs)
@@ -133,7 +133,7 @@ class ServeCommand(BaseTransformersCLICommand):
                 "Or install FastAPI and unicorn separately."
             )
         else:
-            logger.info("Serving model over {}:{}".format(host, port))
+            logger.info(f"Serving model over {host}:{port}")
             self._app = FastAPI(
                 routes=[
                     APIRoute(
...
@@ -104,7 +104,7 @@ class TrainCommand(BaseTransformersCLICommand):
         self.column_text = args.column_text
         self.column_id = args.column_id
 
-        self.logger.info("Loading {} pipeline for {}".format(args.task, args.model))
+        self.logger.info(f"Loading {args.task} pipeline for {args.model}")
         if args.task == "text_classification":
             self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
         elif args.task == "token_classification":
@@ -112,7 +112,7 @@ class TrainCommand(BaseTransformersCLICommand):
         elif args.task == "question_answering":
             raise NotImplementedError
 
-        self.logger.info("Loading dataset from {}".format(args.train_data))
+        self.logger.info(f"Loading dataset from {args.train_data}")
         self.train_dataset = Processor.create_from_csv(
             args.train_data,
             column_label=args.column_label,
@@ -122,7 +122,7 @@ class TrainCommand(BaseTransformersCLICommand):
         )
         self.valid_dataset = None
         if args.validation_data:
-            self.logger.info("Loading validation dataset from {}".format(args.validation_data))
+            self.logger.info(f"Loading validation dataset from {args.validation_data}")
             self.valid_dataset = Processor.create_from_csv(
                 args.validation_data,
                 column_label=args.column_label,
...
@@ -99,15 +99,15 @@ class ANSI:
 
     @classmethod
     def bold(cls, s):
-        return "{}{}{}".format(cls._bold, s, cls._reset)
+        return f"{cls._bold}{s}{cls._reset}"
 
     @classmethod
     def red(cls, s):
-        return "{}{}{}".format(cls._bold + cls._red, s, cls._reset)
+        return f"{cls._bold}{cls._red}{s}{cls._reset}"
 
     @classmethod
     def gray(cls, s):
-        return "{}{}{}".format(cls._gray, s, cls._reset)
+        return f"{cls._gray}{s}{cls._reset}"
 
 
 def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
@@ -268,8 +268,8 @@ class RepoCreateCommand(BaseUserCommand):
         user, _ = self._api.whoami(token)
         namespace = self.args.organization if self.args.organization is not None else user
-        print("You are about to create {}".format(ANSI.bold(namespace + "/" + self.args.name)))
+        full_name = f"{namespace}/{self.args.name}"
+        print(f"You are about to create {ANSI.bold(full_name)}")
         if not self.args.yes:
             choice = input("Proceed? [Y/n] ").lower()
@@ -283,7 +283,7 @@ class RepoCreateCommand(BaseUserCommand):
             print(ANSI.red(e.response.text))
             exit(1)
         print("\nYour repo now lives at:")
-        print(" {}".format(ANSI.bold(url)))
+        print(f" {ANSI.bold(url)}")
         print("\nYou can clone it locally with the command below," " and commit/push as usual.")
         print(f"\n git clone {url}")
         print("")
@@ -328,16 +328,15 @@ class UploadCommand(BaseUserCommand):
             filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path)
             files = [(local_path, filename)]
         else:
-            raise ValueError("Not a valid file or directory: {}".format(local_path))
+            raise ValueError(f"Not a valid file or directory: {local_path}")
 
         if sys.platform == "win32":
             files = [(filepath, filename.replace(os.sep, "/")) for filepath, filename in files]
 
         if len(files) > UPLOAD_MAX_FILES:
             print(
-                "About to upload {} files to S3. This is probably wrong. Please filter files before uploading.".format(
-                    ANSI.bold(len(files))
-                )
+                f"About to upload {ANSI.bold(len(files))} files to S3. This is probably wrong. Please filter files "
+                "before uploading."
             )
             exit(1)
@@ -346,9 +345,8 @@ class UploadCommand(BaseUserCommand):
         for filepath, filename in files:
             print(
-                "About to upload file {} to S3 under filename {} and namespace {}".format(
-                    ANSI.bold(filepath), ANSI.bold(filename), ANSI.bold(namespace)
-                )
+                f"About to upload file {ANSI.bold(filepath)} to S3 under filename {ANSI.bold(filename)} and namespace "
+                f"{ANSI.bold(namespace)}"
             )
 
         if not self.args.yes:
...
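The long messages rewritten above lean on Python's implicit concatenation of adjacent string literals: only the pieces that contain placeholders need the f prefix. A standalone sketch with hypothetical values:

    files = ["pytorch_model.bin", "config.json"]

    message = (
        f"About to upload {len(files)} files to S3. This is probably wrong. "
        "Please filter files before uploading."  # plain literal, no placeholders
    )
    print(message)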
@@ -267,7 +267,7 @@ class PretrainedConfig(object):
             try:
                 setattr(self, key, value)
             except AttributeError as err:
-                logger.error("Can't set {} with value {} for {}".format(key, value, self))
+                logger.error(f"Can't set {key} with value {value} for {self}")
                 raise err
 
     @property
@@ -296,7 +296,7 @@ class PretrainedConfig(object):
     @num_labels.setter
     def num_labels(self, num_labels: int):
         if self.id2label is None or len(self.id2label) != num_labels:
-            self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
+            self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
             self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
 
     def save_pretrained(self, save_directory: Union[str, os.PathLike]):
@@ -309,7 +309,7 @@ class PretrainedConfig(object):
                 Directory where the configuration JSON file will be saved (will be created if it does not exist).
         """
         if os.path.isfile(save_directory):
-            raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
+            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
         os.makedirs(save_directory, exist_ok=True)
         # If we save using the predefined names, we can load using `from_pretrained`
         output_config_file = os.path.join(save_directory, CONFIG_NAME)
@@ -467,16 +467,16 @@ class PretrainedConfig(object):
         except json.JSONDecodeError:
             msg = (
-                "Couldn't reach server at '{}' to download configuration file or "
+                f"Couldn't reach server at '{config_file}' to download configuration file or "
                 "configuration file is not a valid JSON file. "
-                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
+                f"Please check network or file content here: {resolved_config_file}."
             )
             raise EnvironmentError(msg)
 
         if resolved_config_file == config_file:
-            logger.info("loading configuration file {}".format(config_file))
+            logger.info(f"loading configuration file {config_file}")
         else:
-            logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
+            logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
 
         return config_dict, kwargs
@@ -512,7 +512,7 @@ class PretrainedConfig(object):
         for key in to_remove:
             kwargs.pop(key, None)
 
-        logger.info("Model config %s", str(config))
+        logger.info(f"Model config {config}")
         if return_unused_kwargs:
             return config, kwargs
         else:
@@ -544,7 +544,7 @@ class PretrainedConfig(object):
         return self.__dict__ == other.__dict__
 
     def __repr__(self):
-        return "{} {}".format(self.__class__.__name__, self.to_json_string())
+        return f"{self.__class__.__name__} {self.to_json_string()}"
 
     def to_diff_dict(self) -> Dict[str, Any]:
         """
...
@@ -154,7 +154,7 @@ def ensure_valid_input(model, tokens, input_names):
             print(f"{arg_name} is not present in the generated input list.")
             break
 
-    print("Generated inputs order: {}".format(ordered_input_names))
+    print(f"Generated inputs order: {ordered_input_names}")
     return ordered_input_names, tuple(model_args)
...
@@ -294,7 +294,7 @@ def convert_pt_checkpoint_to_tf(
     model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
 ):
     if model_type not in MODEL_CLASSES:
-        raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
+        raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")
 
     config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]
@@ -304,7 +304,7 @@ def convert_pt_checkpoint_to_tf(
         config = config_class.from_json_file(config_file)
     config.output_hidden_states = True
     config.output_attentions = True
-    print("Building TensorFlow model from configuration: {}".format(str(config)))
+    print(f"Building TensorFlow model from configuration: {config}")
     tf_model = model_class(config)
 
     # Load weights from tf checkpoint
@@ -328,11 +328,11 @@ def convert_pt_checkpoint_to_tf(
         np_pt = pto[0].numpy()
         np_tf = tfo[0].numpy()
         diff = np.amax(np.abs(np_pt - np_tf))
-        print("Max absolute difference between models outputs {}".format(diff))
-        assert diff <= 2e-2, "Error, model absolute difference is >2e-2: {}".format(diff)
+        print(f"Max absolute difference between models outputs {diff}")
+        assert diff <= 2e-2, f"Error, model absolute difference is >2e-2: {diff}"
 
     # Save pytorch-model
-    print("Save TensorFlow model to {}".format(tf_dump_path))
+    print(f"Save TensorFlow model to {tf_dump_path}")
     tf_model.save_weights(tf_dump_path, save_format="h5")
@@ -354,12 +354,10 @@ def convert_all_pt_checkpoints_to_tf(
     for j, model_type in enumerate(model_types, start=1):
         print("=" * 100)
-        print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
+        print(f" Converting model type {j}/{len(model_types)}: {model_type}")
         print("=" * 100)
 
         if model_type not in MODEL_CLASSES:
-            raise ValueError(
-                "Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys()))
-            )
+            raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")
 
         config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
@@ -374,16 +372,14 @@ def convert_all_pt_checkpoints_to_tf(
             print("-" * 100)
             if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                 if not only_convert_finetuned_models:
-                    print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
+                    print(f" Skipping finetuned checkpoint {model_shortcut_name}")
                     continue
                 model_type = model_shortcut_name
             elif only_convert_finetuned_models:
-                print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name))
+                print(f" Skipping not finetuned checkpoint {model_shortcut_name}")
                 continue
             print(
-                " Converting checkpoint {}/{}: {} - model_type {}".format(
-                    i, len(aws_config_map), model_shortcut_name, model_type
-                )
+                f" Converting checkpoint {i}/{len(aws_config_map)}: {model_shortcut_name} - model_type {model_type}"
             )
             print("-" * 100)
@@ -422,9 +418,8 @@ if __name__ == "__main__":
         "--model_type",
         default=None,
         type=str,
-        help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(
-            list(MODEL_CLASSES.keys())
-        ),
+        help=f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and "
+        "convert all the models from AWS.",
     )
     parser.add_argument(
         "--pytorch_checkpoint_path",
...
@@ -633,7 +633,7 @@ class T5Converter(SpmConverter):
     def vocab(self, proto):
         num_extra_ids = self.original_tokenizer._extra_ids
         vocab = [(piece.piece, piece.score) for piece in proto.pieces]
-        vocab += [("<extra_id_{}>".format(i), 0.0) for i in range(num_extra_ids - 1, -1, -1)]
+        vocab += [(f"<extra_id_{i}>", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
         return vocab
 
     def post_processor(self):
...
@@ -33,7 +33,7 @@ TOKENIZER_CLASSES = {name: getattr(transformers, name + "Fast") for name in SLOW
 
 def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path, force_download):
     if tokenizer_name is not None and tokenizer_name not in TOKENIZER_CLASSES:
-        raise ValueError("Unrecognized tokenizer name, should be one of {}.".format(list(TOKENIZER_CLASSES.keys())))
+        raise ValueError(f"Unrecognized tokenizer name, should be one of {list(TOKENIZER_CLASSES.keys())}.")
 
     if tokenizer_name is None:
         tokenizer_names = TOKENIZER_CLASSES
@@ -60,9 +60,7 @@ def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path,
         tokenizer = tokenizer_class.from_pretrained(checkpoint, force_download=force_download)
 
         # Save fast tokenizer
-        logger.info(
-            "Save fast tokenizer to {} with prefix {} add_prefix {}".format(dump_path, checkpoint, add_prefix)
-        )
+        logger.info(f"Save fast tokenizer to {dump_path} with prefix {checkpoint} add_prefix {add_prefix}")
 
         # For organization names we create sub-directories
         if "/" in checkpoint:
@@ -75,9 +73,7 @@ def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path,
             checkpoint_prefix_name = None
             dump_path_full = dump_path
 
-        logger.info(
-            "=> {} with prefix {}, add_prefix {}".format(dump_path_full, checkpoint_prefix_name, add_prefix)
-        )
+        logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")
 
         if checkpoint in list(tokenizer.pretrained_vocab_files_map.values())[0]:
             file_path = list(tokenizer.pretrained_vocab_files_map.values())[0][checkpoint]
@@ -86,19 +82,17 @@ def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path,
             dump_path_full = os.path.join(dump_path_full, checkpoint_prefix_name)
             checkpoint_prefix_name = None
 
-        logger.info(
-            "=> {} with prefix {}, add_prefix {}".format(dump_path_full, checkpoint_prefix_name, add_prefix)
-        )
+        logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")
 
         file_names = tokenizer.save_pretrained(
             dump_path_full, legacy_format=False, filename_prefix=checkpoint_prefix_name
         )
-        logger.info("=> File names {}".format(file_names))
+        logger.info(f"=> File names {file_names}")
 
         for file_name in file_names:
             if not file_name.endswith("tokenizer.json"):
                 os.remove(file_name)
-                logger.info("=> removing {}".format(file_name))
+                logger.info(f"=> removing {file_name}")
 
 
 if __name__ == "__main__":
@@ -111,9 +105,8 @@ if __name__ == "__main__":
         "--tokenizer_name",
         default=None,
        type=str,
-        help="Optional tokenizer type selected in the list of {}. If not given, will download and convert all the checkpoints from AWS.".format(
-            list(TOKENIZER_CLASSES.keys())
-        ),
+        help=f"Optional tokenizer type selected in the list of {list(TOKENIZER_CLASSES.keys())}. If not given, will "
+        "download and convert all the checkpoints from AWS.",
     )
     parser.add_argument(
         "--checkpoint_name",
...
@@ -46,7 +46,7 @@ def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_
         model = BertGenerationEncoder(config)
     else:
         model = BertGenerationDecoder(config)
-    print("Building PyTorch model from configuration: {}".format(str(config)))
+    print(f"Building PyTorch model from configuration: {config}")
 
     # Load weights from tf checkpoint
     load_tf_weights_in_bert_generation(
@@ -58,7 +58,7 @@ def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_
     )
 
     # Save pytorch-model
-    print("Save PyTorch model and config to {}".format(pytorch_dump_path))
+    print(f"Save PyTorch model and config to {pytorch_dump_path}")
     model.save_pretrained(pytorch_dump_path)
...
@@ -101,12 +101,7 @@ class GlueDataset(Dataset):
         # Load data features from cache or dataset file
         cached_features_file = os.path.join(
             cache_dir if cache_dir is not None else args.data_dir,
-            "cached_{}_{}_{}_{}".format(
-                mode.value,
-                tokenizer.__class__.__name__,
-                str(args.max_seq_length),
-                args.task_name,
-            ),
+            f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
         )
         label_list = self.processor.get_labels()
         if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
@@ -153,7 +148,7 @@ class GlueDataset(Dataset):
                 torch.save(self.features, cached_features_file)
                 # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
                 logger.info(
-                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
+                    f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
                 )
 
     def __len__(self):
...
@@ -64,11 +64,7 @@ class TextDataset(Dataset):
         directory, filename = os.path.split(file_path)
         cached_features_file = os.path.join(
             cache_dir if cache_dir is not None else directory,
-            "cached_lm_{}_{}_{}".format(
-                tokenizer.__class__.__name__,
-                str(block_size),
-                filename,
-            ),
+            f"cached_lm_{tokenizer.__class__.__name__}_{block_size}_{filename}",
         )
 
         # Make sure only the first process in distributed training processes the dataset,
@@ -105,7 +101,7 @@ class TextDataset(Dataset):
                 with open(cached_features_file, "wb") as handle:
                     pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
                 logger.info(
-                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
+                    f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
                 )
 
     def __len__(self):
@@ -131,7 +127,7 @@ class LineByLineTextDataset(Dataset):
         # Here, we do not cache the features, operating under the assumption
         # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)
-        logger.info("Creating features from dataset file at %s", file_path)
+        logger.info(f"Creating features from dataset file at {file_path}")
 
         with open(file_path, encoding="utf-8") as f:
             lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
@@ -164,8 +160,8 @@ class LineByLineWithRefDataset(Dataset):
         # Here, we do not cache the features, operating under the assumption
         # that we will soon use fast multithreaded tokenizers from the
         # `tokenizers` repo everywhere =)
-        logger.info("Creating features from dataset file at %s", file_path)
-        logger.info("Use ref segment results at %s", ref_path)
+        logger.info(f"Creating features from dataset file at {file_path}")
+        logger.info(f"Use ref segment results at {ref_path}")
         with open(file_path, encoding="utf-8") as f:
             data = f.readlines()  # use this method to avoid delimiter '\u2029' to split a line
         data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
@@ -365,11 +361,7 @@ class TextDatasetForNextSentencePrediction(Dataset):
         directory, filename = os.path.split(file_path)
         cached_features_file = os.path.join(
             directory,
-            "cached_nsp_{}_{}_{}".format(
-                tokenizer.__class__.__name__,
-                str(block_size),
-                filename,
-            ),
+            f"cached_nsp_{tokenizer.__class__.__name__}_{block_size}_{filename}",
         )
 
         self.tokenizer = tokenizer
@@ -427,7 +419,7 @@ class TextDatasetForNextSentencePrediction(Dataset):
                 with open(cached_features_file, "wb") as handle:
                     pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
                 logger.info(
-                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
+                    f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
                 )
 
     def create_examples_from_document(self, document: List[List[int]], doc_index: int):
...
@@ -131,12 +131,7 @@ class SquadDataset(Dataset):
         version_tag = "v2" if args.version_2_with_negative else "v1"
         cached_features_file = os.path.join(
             cache_dir if cache_dir is not None else args.data_dir,
-            "cached_{}_{}_{}_{}".format(
-                mode.value,
-                tokenizer.__class__.__name__,
-                str(args.max_seq_length),
-                version_tag,
-            ),
+            f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{version_tag}",
         )
 
         # Make sure only the first process in distributed training processes the dataset,
@@ -184,7 +179,7 @@ class SquadDataset(Dataset):
                 )
                 # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
                 logger.info(
-                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
+                    f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
                 )
 
     def __len__(self):
...
@@ -96,7 +96,7 @@ def get_raw_scores(examples, preds):
             gold_answers = [""]
 
         if qas_id not in preds:
-            print("Missing prediction for %s" % qas_id)
+            print(f"Missing prediction for {qas_id}")
             continue
 
         prediction = preds[qas_id]
@@ -140,7 +140,7 @@ def make_eval_dict(exact_scores, f1_scores, qid_list=None):
 
 def merge_eval(main_eval, new_eval, prefix):
     for k in new_eval:
-        main_eval["%s_%s" % (prefix, k)] = new_eval[k]
+        main_eval[f"{prefix}_{k}"] = new_eval[k]
 
 
 def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
@@ -302,7 +302,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
     start_position = tok_text.find(pred_text)
     if start_position == -1:
         if verbose_logging:
-            logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
+            logger.info(f"Unable to find text: '{pred_text}' in '{orig_text}'")
         return orig_text
     end_position = start_position + len(pred_text) - 1
@@ -311,7 +311,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
     if len(orig_ns_text) != len(tok_ns_text):
         if verbose_logging:
-            logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text)
+            logger.info(f"Length not equal after stripping spaces: '{orig_ns_text}' vs '{tok_ns_text}'")
         return orig_text
 
     # We then project the characters in `pred_text` back to `orig_text` using
@@ -615,8 +615,7 @@ def compute_predictions_log_probs(
         "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
     )
 
-    logger.info("Writing predictions to: %s", output_prediction_file)
-    # logger.info("Writing nbest to: %s" % (output_nbest_file))
+    logger.info(f"Writing predictions to: {output_prediction_file}")
 
     example_index_to_features = collections.defaultdict(list)
     for feature in all_features:
...
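Several hunks above also rewrite lazy %-style logging calls such as logger.info("... %s", path) into f-strings. A sketch of both forms under a standard logging setup; the f-string is formatted eagerly, before logger.info() runs, while the lazy form defers formatting until the record is actually emitted:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    file_path = "train.txt"  # hypothetical value

    # Lazy form: the message is only formatted if the record is emitted.
    logger.info("Creating features from dataset file at %s", file_path)

    # f-string form, as adopted in this commit: formatted before the call.
    logger.info(f"Creating features from dataset file at {file_path}")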
@@ -122,10 +122,10 @@ def _glue_convert_examples_to_features(
     processor = glue_processors[task]()
     if label_list is None:
         label_list = processor.get_labels()
-        logger.info("Using label list %s for task %s" % (label_list, task))
+        logger.info(f"Using label list {label_list} for task {task}")
     if output_mode is None:
         output_mode = glue_output_modes[task]
-        logger.info("Using output mode %s for task %s" % (output_mode, task))
+        logger.info(f"Using output mode {output_mode} for task {task}")
 
     label_map = {label: i for i, label in enumerate(label_list)}
@@ -156,8 +156,8 @@ def _glue_convert_examples_to_features(
     for i, example in enumerate(examples[:5]):
         logger.info("*** Example ***")
-        logger.info("guid: %s" % (example.guid))
-        logger.info("features: %s" % features[i])
+        logger.info(f"guid: {example.guid}")
+        logger.info(f"features: {features[i]}")
 
     return features
@@ -185,7 +185,7 @@ class MrpcProcessor(DataProcessor):
 
     def get_train_examples(self, data_dir):
         """See base class."""
-        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
+        logger.info(f"LOOKING AT {os.path.join(data_dir, 'train.tsv')}")
         return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
 
     def get_dev_examples(self, data_dir):
@@ -206,7 +206,7 @@ class MrpcProcessor(DataProcessor):
         for (i, line) in enumerate(lines):
             if i == 0:
                 continue
-            guid = "%s-%s" % (set_type, i)
+            guid = f"{set_type}-{i}"
             text_a = line[3]
             text_b = line[4]
             label = None if set_type == "test" else line[0]
@@ -252,7 +252,7 @@ class MnliProcessor(DataProcessor):
         for (i, line) in enumerate(lines):
             if i == 0:
                 continue
-            guid = "%s-%s" % (set_type, line[0])
+            guid = f"{set_type}-{line[0]}"
             text_a = line[8]
             text_b = line[9]
             label = None if set_type.startswith("test") else line[-1]
@@ -316,7 +316,7 @@ class ColaProcessor(DataProcessor):
         text_index = 1 if test_mode else 3
         examples = []
         for (i, line) in enumerate(lines):
-            guid = "%s-%s" % (set_type, i)
+            guid = f"{set_type}-{i}"
             text_a = line[text_index]
             label = None if test_mode else line[1]
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
@@ -362,7 +362,7 @@ class Sst2Processor(DataProcessor):
         for (i, line) in enumerate(lines):
             if i == 0:
                 continue
-            guid = "%s-%s" % (set_type, i)
+            guid = f"{set_type}-{i}"
             text_a = line[text_index]
             label = None if set_type == "test" else line[1]
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
@@ -407,7 +407,7 @@ class StsbProcessor(DataProcessor):
         for (i, line) in enumerate(lines):
             if i == 0:
                 continue
-            guid = "%s-%s" % (set_type, line[0])
+            guid = f"{set_type}-{line[0]}"
             text_a = line[7]
             text_b = line[8]
             label = None if set_type == "test" else line[-1]
@@ -456,7 +456,7 @@ class QqpProcessor(DataProcessor):
         for (i, line) in enumerate(lines):
             if i == 0:
                 continue
-            guid = "%s-%s" % (set_type, line[0])
+            guid = f"{set_type}-{line[0]}"
             try:
                 text_a = line[q1_index]
                 text_b = line[q2_index]
@@ -505,7 +505,7 @@ class QnliProcessor(DataProcessor):
         for (i, line) in enumerate(lines):
             if i == 0:
                 continue
-            guid = "%s-%s" % (set_type, line[0])
+            guid = f"{set_type}-{line[0]}"
             text_a = line[1]
             text_b = line[2]
             label = None if set_type == "test" else line[-1]
@@ -551,7 +551,7 @@ class RteProcessor(DataProcessor):
         for (i, line) in enumerate(lines):
             if i == 0:
                 continue
-            guid = "%s-%s" % (set_type, line[0])
+            guid = f"{set_type}-{line[0]}"
             text_a = line[1]
             text_b = line[2]
             label = None if set_type == "test" else line[-1]
@@ -597,7 +597,7 @@ class WnliProcessor(DataProcessor):
         for (i, line) in enumerate(lines):
             if i == 0:
                 continue
-            guid = "%s-%s" % (set_type, line[0])
+            guid = f"{set_type}-{line[0]}"
             text_a = line[1]
             text_b = line[2]
             label = None if set_type == "test" else line[-1]
...
@@ -115,7 +115,7 @@ def squad_convert_example_to_features(
         actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
         cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
         if actual_text.find(cleaned_answer_text) == -1:
-            logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
+            logger.warning(f"Could not find answer: '{actual_text}' vs. '{cleaned_answer_text}'")
             return []
 
     tok_to_orig_index = []
...
@@ -186,7 +186,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
             if column_id is not None:
                 ids.append(line[column_id])
             else:
-                guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
+                guid = f"{split_name}-{i}" if split_name else str(i)
                 ids.append(guid)
 
         return self.add_examples(
@@ -265,7 +265,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
         all_input_ids = []
         for (ex_index, example) in enumerate(self.examples):
             if ex_index % 10000 == 0:
-                logger.info("Tokenizing example %d", ex_index)
+                logger.info(f"Tokenizing example {ex_index}")
 
             input_ids = tokenizer.encode(
                 example.text_a,
@@ -279,7 +279,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
         features = []
         for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, self.examples)):
             if ex_index % 10000 == 0:
-                logger.info("Writing example %d/%d" % (ex_index, len(self.examples)))
+                logger.info(f"Writing example {ex_index}/{len(self.examples)}")
             # The mask has 1 for real tokens and 0 for padding tokens. Only real
             # tokens are attended to.
             attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
@@ -293,12 +293,10 @@ class SingleSentenceClassificationProcessor(DataProcessor):
             input_ids = input_ids + ([pad_token] * padding_length)
             attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
 
-            assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(
-                len(input_ids), batch_length
-            )
-            assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(
-                len(attention_mask), batch_length
-            )
+            assert len(input_ids) == batch_length, f"Error with input length {len(input_ids)} vs {batch_length}"
+            assert (
+                len(attention_mask) == batch_length
+            ), f"Error with input length {len(attention_mask)} vs {batch_length}"
 
             if self.mode == "classification":
                 label = label_map[example.label]
@@ -309,10 +307,10 @@ class SingleSentenceClassificationProcessor(DataProcessor):
 
             if ex_index < 5 and self.verbose:
                 logger.info("*** Example ***")
-                logger.info("guid: %s" % (example.guid))
-                logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
-                logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
-                logger.info("label: %s (id = %d)" % (example.label, label))
+                logger.info(f"guid: {example.guid}")
+                logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
+                logger.info(f"attention_mask: {' '.join([str(x) for x in attention_mask])}")
+                logger.info(f"label: {example.label} (id = {label})")
 
             features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
...
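Two smaller details from the hunks above, shown as a standalone sketch (paths and values are hypothetical): an expression inside an f-string placeholder must use a different quote character from the enclosing literal on the Python versions this code targets (relaxed only in Python 3.12), and %-style precision such as %.3f maps onto a :.3f format spec inside the placeholder:

    import os
    import time

    data_dir = "/tmp/glue"
    start = time.time()

    # Inner literal uses single quotes inside the double-quoted f-string.
    print(f"LOOKING AT {os.path.join(data_dir, 'train.tsv')}")

    # "%.3f" % elapsed becomes a :.3f format spec in the placeholder.
    print(f"took {time.time() - start:.3f} s")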