#!/bin/bash

# Require the image reference to test as the first (non-empty) argument.
# Diagnostics go to stderr so they are not mixed into the test output.
if [[ -z "$1" ]]; then
  echo "please set input image" >&2
  echo "usage: ${0##*/} <image[:tag]>" >&2
  exit 1
fi

# Pick a verification routine from the framework name contained in the image
# reference ($1) and run it in a throw-away GPU container. The script's exit
# status is that of the `docker run` performed, or 1 for an unsupported image.
image=$1
# Flags shared by every verification container.
docker_run=(docker run --rm --platform=linux/amd64 --gpus all)

# Python snippet: prints the OS banner plus Python, PyTorch (CUDA / cuDNN),
# torchvision and torchaudio versions.
pytorch_check='
import os
os.system("cat /etc/issue")
import sys
print("python version: ", sys.version)
import torch
print("torch version: ", torch.__version__)
print("torch cuda available: ", torch.cuda.is_available())
print("torch cuda version: ", torch.version.cuda)
print("torch cudnn version: ", torch.backends.cudnn.version())
import torchvision
print("torchvision version: ", torchvision.__version__)
import torchaudio
print("torchaudio version: ", torchaudio.__version__)
'

# Python snippet: prints the OS banner, Python and TensorFlow versions, whether
# TensorFlow sees a GPU, and the last two lines of `nvcc -V`.
# tf.test.is_gpu_available() is deprecated in TF2 (in favour of
# tf.config.list_physical_devices("GPU")) but is kept so TF1.x images still work.
tensorflow_check='
import os
os.system("cat /etc/issue")
import sys
print("python version: ", sys.version)
import tensorflow as tf
print("tensorflow version: ", tf.__version__)
print("tensorflow cuda available: ", tf.test.is_gpu_available())
os.system("nvcc -V | tail -n 2")
'

case "$image" in
  *pytorch*)
    "${docker_run[@]}" "$image" python -c "$pytorch_check"
    ;;
  *tensorflow*)
    # Take the tag from the last path component only, so a registry port
    # (e.g. host:5000/tensorflow:2.16.1-...) is not mistaken for the tag.
    image_name=${image##*/}
    image_tag=${image_name#*:}
    tensorflow_version=${image_tag%%-*}
    tf_env=()
    # With TensorFlow 2.16.1, CUDA is not found unless these environment
    # variables are set, so they are passed explicitly for this check. A normal
    # interactive container sources /etc/bash.bashrc by default and finds CUDA.
    if [[ "$tensorflow_version" == "2.16.1" ]]; then
      python_version=
      # Match a '-'-separated "pyX.Y" component of the tag, e.g. "...-py3.10".
      py_re='-py([0-9]+\.[0-9]+)-'
      if [[ "-${image_tag}-" =~ $py_re ]]; then
        python_version=${BASH_REMATCH[1]}
      else
        echo "WARNING: no pyX.Y component found in the tag of $image; cuDNN paths may be wrong" >&2
      fi
      cudnn_dir="/opt/conda/lib/python${python_version}/site-packages/nvidia/cudnn"
      tf_env=(
        -e "CUDNN_PATH=${cudnn_dir}"
        -e "LD_LIBRARY_PATH=${cudnn_dir}/lib:/usr/local/cuda/lib64"
      )
    fi
    "${docker_run[@]}" "${tf_env[@]}" "$image" python -c "$tensorflow_check"
    ;;
  *paddle*)
    TARGET_DIR=gpu-base-image-test/paddletest
    # `docker run -v` silently creates a missing host directory (empty), which
    # would make base_test.py fail confusingly, so check for it up front.
    if [[ ! -d "$TARGET_DIR" ]]; then
      echo "ERROR: test directory $PWD/$TARGET_DIR not found" >&2
      exit 1
    fi
    # Absolute source path: older Docker CLIs reject relative -v sources.
    "${docker_run[@]}" -v "$PWD/$TARGET_DIR:/workspace" --workdir /workspace \
      "$image" python base_test.py
    ;;
  *)
    echo "ERROR: no supported test shell" >&2
    exit 1
    ;;
esac


