predict_demo.py 346 Bytes
Newer Older
mashun's avatar
mashun committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import argparse

import numpy as np
import tensorflow as tf

from keras_train import Dataset


# Demo: load a saved Keras checkpoint (.h5), run it over one converted
# validation file, and print the raw model outputs.

DEFAULT_MODEL_PATH = 'model_checkpoints/particle_net_lite_model.001.h5'
DEFAULT_DATA_PATH = 'converted/val_file_0.awkd'


def parse_args(argv=None):
    """Parse command-line options for the prediction demo.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            means ``sys.argv[1:]`` (argparse's default).

    Returns:
        argparse.Namespace with ``model`` and ``data`` path attributes. Both
        default to the paths this script originally hard-coded, so running
        with no arguments behaves as before.
    """
    parser = argparse.ArgumentParser(
        description='Load a saved Keras model (.h5) and print its predictions '
                    'on a converted validation file.')
    parser.add_argument(
        '--model', default=DEFAULT_MODEL_PATH,
        help='path to the saved Keras model (default: %(default)s)')
    parser.add_argument(
        '--data', default=DEFAULT_DATA_PATH,
        help='path to the converted .awkd dataset (default: %(default)s)')
    return parser.parse_args(argv)


def main(argv=None):
    """Load the model, run it on the dataset, and print the predictions.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.

    Returns:
        Whatever ``model.predict`` returns, so the demo can be reused from
        other code instead of only printing.
    """
    args = parse_args(argv)

    # Load the saved Keras model (.h5).
    model = tf.keras.models.load_model(args.model)

    # 'channel_last' is forwarded unchanged to keras_train.Dataset; presumably
    # it must match the layout the model was trained with -- TODO confirm.
    val_dataset = Dataset(args.data, data_format='channel_last')

    # NOTE(review): assumes val_dataset.X is already in a form model.predict
    # accepts (e.g. arrays keyed by the model's input names) -- verify in
    # keras_train.
    predictions = model.predict(val_dataset.X)

    print(predictions)
    return predictions


# Guard the entry point so importing this module does not load the model
# and run inference as a side effect.
if __name__ == '__main__':
    main()