Commit 0e621bb8 authored by Michael Figurnov's avatar Michael Figurnov
Browse files

Refactoring of precision calculation.

Uses more NumPy now :)
parent a5c4fd06
...@@ -61,9 +61,6 @@ def train(hps): ...@@ -61,9 +61,6 @@ def train(hps):
sess = sv.prepare_or_wait_for_session() sess = sv.prepare_or_wait_for_session()
step = 0 step = 0
total_prediction = 0
correct_prediction = 0
precision = 0.0
lrn_rate = 0.1 lrn_rate = 0.1
while not sv.should_stop(): while not sv.should_stop():
...@@ -81,14 +78,9 @@ def train(hps): ...@@ -81,14 +78,9 @@ def train(hps):
else: else:
lrn_rate = 0.0001 lrn_rate = 0.0001
predictions = np.argmax(predictions, axis=1)
truth = np.argmax(truth, axis=1) truth = np.argmax(truth, axis=1)
for (t, p) in zip(truth, predictions): predictions = np.argmax(predictions, axis=1)
if t == p: precision = np.mean(truth == predictions)
correct_prediction += 1
total_prediction += 1
precision = float(correct_prediction) / total_prediction
correct_prediction = total_prediction = 0
step += 1 step += 1
if step % 100 == 0: if step % 100 == 0:
...@@ -135,12 +127,10 @@ def evaluate(hps): ...@@ -135,12 +127,10 @@ def evaluate(hps):
[model.summaries, model.cost, model.predictions, [model.summaries, model.cost, model.predictions,
model.labels, model.global_step]) model.labels, model.global_step])
best_predictions = np.argmax(predictions, axis=1)
truth = np.argmax(truth, axis=1) truth = np.argmax(truth, axis=1)
for (t, p) in zip(truth, best_predictions): predictions = np.argmax(predictions, axis=1)
if t == p: correct_prediction += np.sum(truth == predictions)
correct_prediction += 1 total_prediction += predictions.shape[0]
total_prediction += 1
precision = 1.0 * correct_prediction / total_prediction precision = 1.0 * correct_prediction / total_prediction
best_precision = max(precision, best_precision) best_precision = max(precision, best_precision)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment