Unverified Commit 667ccea7 authored by Katie Le's avatar Katie Le Committed by GitHub
Browse files

Replace assertion with ValueError exceptions in run_image_captioning_flax.py (#20365)



* replace 4 asserts with ValueError exception for control flow

* Update examples/flax/image-captioning/run_image_captioning_flax.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update examples/flax/image-captioning/run_image_captioning_flax.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* reformatted file

* uninstalled trasformers and applied make style
Co-authored-by: default avatarBibi <Bibi@katies-mac.local>
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 0a619325
...@@ -298,10 +298,12 @@ class DataTrainingArguments: ...@@ -298,10 +298,12 @@ class DataTrainingArguments:
else: else:
if self.train_file is not None: if self.train_file is not None:
extension = self.train_file.split(".")[-1] extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." if extension not in ["csv", "json"]:
raise ValueError(f"`train_file` should be a csv or a json file, got {extension}.")
if self.validation_file is not None: if self.validation_file is not None:
extension = self.validation_file.split(".")[-1] extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." if extension not in ["csv", "json"]:
raise ValueError(f"`validation_file` should be a csv or a json file, got {extension}.")
if self.val_max_target_length is None: if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length self.val_max_target_length = self.max_target_length
...@@ -502,7 +504,12 @@ def main(): ...@@ -502,7 +504,12 @@ def main():
# Get the column names for input/target. # Get the column names for input/target.
dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None) dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None)
if data_args.image_column is None: if data_args.image_column is None:
assert dataset_columns is not None if dataset_columns is None:
raise ValueError(
f"`--dataset_name` {data_args.dataset_name} not found in dataset '{data_args.dataset_name}'. Make sure"
" to set `--dataset_name` to the correct dataset name, one of"
f" {', '.join(image_captioning_name_mapping.keys())}."
)
image_column = dataset_columns[0] image_column = dataset_columns[0]
else: else:
image_column = data_args.image_column image_column = data_args.image_column
...@@ -511,7 +518,12 @@ def main(): ...@@ -511,7 +518,12 @@ def main():
f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}" f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
) )
if data_args.caption_column is None: if data_args.caption_column is None:
assert dataset_columns is not None if dataset_columns is None:
raise ValueError(
f"`--dataset_name` {data_args.dataset_name} not found in dataset '{data_args.dataset_name}'. Make sure"
" to set `--dataset_name` to the correct dataset name, one of"
f" {', '.join(image_captioning_name_mapping.keys())}."
)
caption_column = dataset_columns[1] caption_column = dataset_columns[1]
else: else:
caption_column = data_args.caption_column caption_column = data_args.caption_column
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment