Commit a7666964 authored by qianyj's avatar qianyj
Browse files

update TensorFlow test code

parent 3f2973d6
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
## 运行 ## 运行
# sed指令只需要执行一次,添加支持多卡运行的代码 # sed指令只需要执行一次,添加支持多卡运行的代码
sed -i '99 r configfile' official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py sed -i '101 r configfile' official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
export PYTHONPATH=/path/to/tensorflow/model:$PYTHONPATH export PYTHONPATH=/path/to/tensorflow/model:$PYTHONPATH
mpirun -np ${num_gpu} --hostfile hostfile -mca btl self,tcp --bind-to none scripts-run/single_process.sh mpirun -np ${num_gpu} --hostfile hostfile -mca btl self,tcp --bind-to none scripts-run/single_process.sh
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
## 运行 ## 运行
# sed指令只需要执行一次,添加支持多卡运行的代码 # sed指令只需要执行一次,添加支持多卡运行的代码
sed -i '99 r configfile' official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py sed -i '101 r configfile' official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
修改scripts-run/single_process.sh中的--dtype=fp16 修改scripts-run/single_process.sh中的--dtype=fp16
export PYTHONPATH=/path/to/tensorflow/model:$PYTHONPATH export PYTHONPATH=/path/to/tensorflow/model:$PYTHONPATH
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import math import math
import os import os
import json
# Import libraries # Import libraries
from absl import app from absl import app
...@@ -97,6 +98,7 @@ def run(flags_obj): ...@@ -97,6 +98,7 @@ def run(flags_obj):
Returns: Returns:
Dictionary of training and eval stats. Dictionary of training and eval stats.
""" """
keras_utils.set_session_config() keras_utils.set_session_config()
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj)) performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
## 分布式多卡 ## 分布式多卡
# sed指令只需要执行一次,添加支持多卡运行的代码 # sed指令只需要执行一次,添加支持多卡运行的代码
sed -i '99 r configfile' models-master/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py sed -i '101 r configfile' models-master/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
export PYTHONPATH=/path/to/tensorflow/model:$PYTHONPATH export PYTHONPATH=/path/to/tensorflow/model:$PYTHONPATH
mpirun -np ${num_gpu} --hostfile hostfile -mca btl self,tcp --bind-to none scripts-run/single_process.sh mpirun -np ${num_gpu} --hostfile hostfile -mca btl self,tcp --bind-to none scripts-run/single_process.sh
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import math import math
import os import os
import json import json
# Import libraries # Import libraries
from absl import app from absl import app
from absl import flags from absl import flags
...@@ -97,6 +98,7 @@ def run(flags_obj): ...@@ -97,6 +98,7 @@ def run(flags_obj):
Returns: Returns:
Dictionary of training and eval stats. Dictionary of training and eval stats.
""" """
keras_utils.set_session_config() keras_utils.set_session_config()
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj)) performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment