Commit 0a59f103 authored by Scott Thornton's avatar Scott Thornton
Browse files

Fixed up MNIST example

parent 09c946b2
...@@ -298,7 +298,9 @@ struct reshape ...@@ -298,7 +298,9 @@ struct reshape
rdims.pop_back(); rdims.pop_back();
std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims)); std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
} }
return {inputs.front().type(), rdims}; shape s{inputs.front().type(), rdims};
if (s.elements() != inputs.front().elements()) RTG_THROW("Wrong number of elements");
return s;
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
......
...@@ -110,6 +110,20 @@ std::vector<int32_t> read_mnist_labels(std::string full_path, int& number_of_lab ...@@ -110,6 +110,20 @@ std::vector<int32_t> read_mnist_labels(std::string full_path, int& number_of_lab
} }
} }
std::vector<float> softmax(std::vector<float> p) {
size_t n = p.size();
std::vector<float> result(n);
float s = 0.0f;
for (size_t i = 0; i < n; i++) {
result[i] = std::exp(p[i]);
s += result[i];
}
for (size_t i = 0; i < n; i++) {
result[i] = result[i]/s;
}
return result;
}
int main(int argc, char const* argv[]) int main(int argc, char const* argv[])
{ {
if(argc > 1) if(argc > 1)
...@@ -122,26 +136,24 @@ int main(int argc, char const* argv[]) ...@@ -122,26 +136,24 @@ int main(int argc, char const* argv[])
std::vector<float> input = read_mnist_images(datafile, nimages, image_size); std::vector<float> input = read_mnist_images(datafile, nimages, image_size);
std::vector<int32_t> labels = read_mnist_labels(labelfile, nlabels); std::vector<int32_t> labels = read_mnist_labels(labelfile, nlabels);
printf("label: %d\n\n", labels[0]);
for(int i = 7; i < 9; i++)
{
for(int j = 0; j < 28; j++)
{
printf("%8.5f ", input[i * 28 + j]);
}
printf("\n");
}
std::string file = argv[1]; std::string file = argv[1];
auto prog = rtg::parse_onnx(file); auto prog = rtg::parse_onnx(file);
prog.compile(rtg::cpu::cpu_target{}); prog.compile(rtg::cpu::cpu_target{});
// auto s = prog.get_parameter_shape("Input3"); // auto s = prog.get_parameter_shape("Input3");
auto s = rtg::shape{rtg::shape::float_type, {1, 1, 28, 28}}; auto s = rtg::shape{rtg::shape::float_type, {1, 1, 28, 28}};
std::cout << s << std::endl; std::cout << s << std::endl;
auto input3 = rtg::argument{s, input.data()}; auto ptr = input.data();
auto out = prog.eval({{"Input3", input3}}); for (int i = 0; i < 20; i++)
std::cout << out << std::endl; {
std::cout << prog << std::endl; printf("label: %d ----> ", labels[i]);
auto input3 = rtg::argument{s, &ptr[784*i]};
auto result = prog.eval({{"Input3", input3}});
std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax(logits);
for (auto x : probs) printf("%8.4f ", x);
printf("\n");
}
printf("\n");
} }
} }
...@@ -141,7 +141,6 @@ argument program::eval(std::unordered_map<std::string, argument> params) const ...@@ -141,7 +141,6 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
values.begin(), values.begin(),
[&](instruction_ref i) { return results.at(std::addressof(*i)); }); [&](instruction_ref i) { return results.at(std::addressof(*i)); });
result = ins.op.compute(ins.result, values); result = ins.op.compute(ins.result, values);
std::cout << "Debug: " << ins.op.name() << "\n";
if(result.get_shape().elements() > 0 and result.get_shape().packed() and if(result.get_shape().elements() > 0 and result.get_shape().packed() and
std::isnan(result.at<float>())) std::isnan(result.at<float>()))
std::cout << "Nan: " << ins.op.name() << std::endl; std::cout << "Nan: " << ins.op.name() << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment