# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Smoke test for `vit_jax.train.train_and_evaluate`."""

import os
import tempfile

from absl.testing import absltest
from absl.testing import parameterized
import ml_collections
import tensorflow_datasets as tfds
from vit_jax import test_utils
from vit_jax import train
from vit_jax.configs import common
from vit_jax.configs import models

# A 1x1 black JPEG image. Every file of the generated image-folder dataset
# (see `_write_image_folder_dataset`) has exactly these bytes. Produced with:
#
#   from PIL import Image
#   import numpy as np
#   Image.fromarray(np.array([[[0, 0, 0]]], np.uint8)).save('black1px.jpg')
#   with open('black1px.jpg', 'rb') as f:
#     print(repr(f.read()))
JPG_BLACK_1PX = (b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c'
                 b' $.\' '
                 b'",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\t\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xc0\x00\x11\x08\x00\x01\x00\x01\x03\x01"\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x1f\x00\x00\x01\x05\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x10\x00\x02\x01\x03\x03\x02\x04\x03\x05\x05\x04\x04\x00\x00\x01}\x01\x02\x03\x00\x04\x11\x05\x12!1A\x06\x13Qa\x07"q\x142\x81\x91\xa1\x08#B\xb1\xc1\x15R\xd1\xf0$3br\x82\t\n\x16\x17\x18\x19\x1a%&\'()*456789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xff\xc4\x00\x1f\x01\x00\x03\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x11\x00\x02\x01\x02\x04\x04\x03\x04\x07\x05\x04\x04\x00\x01\x02w\x00\x01\x02\x03\x11\x04\x05!1\x06\x12AQ\x07aq\x13"2\x81\x08\x14B\x91\xa1\xb1\xc1\t#3R\xf0\x15br\xd1\n\x16$4\xe1%\xf1\x17\x18\x19\x1a&\'()*56789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xff\xda\x00\x0c\x03\x01\x00\x02\x11\x03\x11\x00?\x00\xf9\xfe\x8a(\xa0\x0f\xff\xd9')  # pylint: disable=line-too-long


def _write_image_folder_dataset(root,
                                splits=('train', 'test'),
                                class_names=('test1', 'test2'),
                                images_per_class=8):
  """Writes a tiny image-folder dataset of 1x1 black JPEGs under `root`.

  The resulting layout is `<root>/<split>/<class_name>/<i>.jpg`.

  Args:
    root: Directory under which the dataset is created (created if missing).
    splits: Names of the top-level split directories.
    class_names: Names of the class directories created inside each split.
    images_per_class: Number of images written into each class directory.
  """
  for split in splits:
    for class_name in class_names:
      class_dir = os.path.join(root, split, class_name)
      # Created once per class directory rather than once per image.
      os.makedirs(class_dir, exist_ok=True)
      for i in range(images_per_class):
        with open(os.path.join(class_dir, f'{i}.jpg'), 'wb') as f:
          f.write(JPG_BLACK_1PX)


class TrainTest(parameterized.TestCase):
  """End-to-end smoke tests for `train.train_and_evaluate`."""

  @parameterized.named_parameters(
      ('tfds', 'tfds'),
      ('directory', 'directory'),
  )
  def test_train_and_evaluate(self, dataset_source):
    """Runs a single training step and checks that a checkpoint is written.

    Args:
      dataset_source: 'tfds' to train on the TFDS 'cifar10' dataset, or
        'directory' to train on a generated image-folder dataset inside the
        temporary working directory.

    Raises:
      ValueError: If `dataset_source` is not one of the values above.
    """
    config = common.get_config()
    config.model = models.get_testing_config()
    config.batch = 64
    config.accum_steps = 2
    config.batch_eval = 8
    # One step keeps the test fast; the assertion below expects checkpoint_1.
    config.total_steps = 1

    with tempfile.TemporaryDirectory() as workdir:
      if dataset_source == 'tfds':
        config.dataset = 'cifar10'
        config.pp = ml_collections.ConfigDict({
            'train': 'train[:98%]',  # TFDS slicing: first 98% of 'train'.
            'test': 'test',
            'crop': 224,
        })
      elif dataset_source == 'directory':
        config.dataset = os.path.join(workdir, 'dataset')
        config.pp = ml_collections.ConfigDict({'crop': 224})
        _write_image_folder_dataset(config.dataset)
      else:
        raise ValueError(f'Unknown dataset_source: "{dataset_source}"')

      # NOTE(review): assumes training loads `<pretrained_dir>/<model name>.npz`
      # and that the testing model config is named 'testing' - confirm.
      config.pretrained_dir = workdir
      test_utils.create_checkpoint(config.model,
                                   os.path.join(workdir, 'testing.npz'))

      train.train_and_evaluate(config, workdir)

      checkpoint_path = os.path.join(workdir, 'checkpoint_1')
      self.assertTrue(
          os.path.exists(checkpoint_path),
          msg=(f'{checkpoint_path} was not written; workdir contains '
               f'{sorted(os.listdir(workdir))}'))


if __name__ == '__main__':
  absltest.main()