Support fp16 using tf.keras.mixed_precision in CTL resnet.
To test, I ran the following command: python resnet_ctl_imagenet_main.py --batch_size=2048 --data_dir ~/imagenet --datasets_num_private_threads=14 --epochs_between_evals=10 --model_dir ~/tmp_model_dir --clean --num_gpus=8 --train_epochs=90 --dtype=fp16 I got 76.15% final evaluation accuracy. PiperOrigin-RevId: 278010061
Showing
Please register or sign in to comment