Unverified Commit acb709d5 authored by Zachary Mueller's avatar Zachary Mueller Committed by GitHub
Browse files

Change no trainer image_classification test (#17635)

* Adjust test arguments and use a new example test
parent e70abdad
......@@ -300,19 +300,25 @@ class ExamplesTestsNoTrainer(TestCasePlus):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
{self.examples_dir}/pytorch/image-classification/run_image_classification_no_trainer.py
--dataset_name huggingface/image-classification-test-sample
--model_name_or_path google/vit-base-patch16-224-in21k
--dataset_name hf-internal-testing/cats_vs_dogs_sample
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 1
--max_train_steps 2
--train_val_split 0.1
--seed 42
--output_dir {tmp_dir}
--num_warmup_steps=8
--learning_rate=3e-3
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--checkpointing_steps epoch
--with_tracking
--seed 42
--checkpointing_steps 1
""".split()
if is_cuda_and_apex_available():
testargs.append("--fp16")
_ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.50)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
# The base model scores a 25%
self.assertGreaterEqual(result["eval_accuracy"], 0.625)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "step_1")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment