Commit 0e621bb8 authored by Michael Figurnov's avatar Michael Figurnov
Browse files

Refactoring of precision calculation.

Uses more NumPy now :)
parent a5c4fd06
......@@ -61,9 +61,6 @@ def train(hps):
sess = sv.prepare_or_wait_for_session()
step = 0
total_prediction = 0
correct_prediction = 0
precision = 0.0
lrn_rate = 0.1
while not sv.should_stop():
......@@ -81,14 +78,9 @@ def train(hps):
else:
lrn_rate = 0.0001
predictions = np.argmax(predictions, axis=1)
truth = np.argmax(truth, axis=1)
for (t, p) in zip(truth, predictions):
if t == p:
correct_prediction += 1
total_prediction += 1
precision = float(correct_prediction) / total_prediction
correct_prediction = total_prediction = 0
predictions = np.argmax(predictions, axis=1)
precision = np.mean(truth == predictions)
step += 1
if step % 100 == 0:
......@@ -135,12 +127,10 @@ def evaluate(hps):
[model.summaries, model.cost, model.predictions,
model.labels, model.global_step])
best_predictions = np.argmax(predictions, axis=1)
truth = np.argmax(truth, axis=1)
for (t, p) in zip(truth, best_predictions):
if t == p:
correct_prediction += 1
total_prediction += 1
predictions = np.argmax(predictions, axis=1)
correct_prediction += np.sum(truth == predictions)
total_prediction += predictions.shape[0]
precision = 1.0 * correct_prediction / total_prediction
best_precision = max(precision, best_precision)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment