Commit 12e9c7d5 authored by Scott Thornton's avatar Scott Thornton
Browse files

Formatting

parent 056c6fe2
......@@ -96,12 +96,13 @@ std::vector<int32_t> read_mnist_labels(std::string full_path, int& number_of_lab
}
}
std::vector<float> softmax(std::vector<float> p) {
std::vector<float> softmax(std::vector<float> p)
{
size_t n = p.size();
std::vector<float> result(n);
std::transform(p.begin(), p.end(), result.begin(), [] (auto x) {return std::exp(x);});
std::transform(p.begin(), p.end(), result.begin(), [](auto x) { return std::exp(x); });
float s = std::accumulate(result.begin(), result.end(), 0.0f, std::plus<float>());
std::transform(result.begin(), result.end(), result.begin(), [=] (auto x) {return x/s;});
std::transform(result.begin(), result.end(), result.begin(), [=](auto x) { return x / s; });
return result;
}
......@@ -123,15 +124,16 @@ int main(int argc, char const* argv[])
auto s = migraph::shape{migraph::shape::float_type, {1, 1, 28, 28}};
std::cout << s << std::endl;
auto ptr = input.data();
for (int i = 0; i < 20; i++)
for(int i = 0; i < 20; i++)
{
std::cout << "label: " << labels[i] << " ----> ";
auto input3 = migraph::argument{s, &ptr[784*i]};
auto input3 = migraph::argument{s, &ptr[784 * i]};
auto result = prog.eval({{"Input3", input3}});
std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax(logits);
for (auto x : probs) std::cout << x << " ";
for(auto x : probs)
std::cout << x << " ";
std::cout << std::endl;
}
std::cout << std::endl;
......
......@@ -621,7 +621,7 @@ int main()
transpose_test();
contiguous_test();
softmax_test();
//maxpool_test();
// maxpool_test();
conv2d_test();
conv2d_padding_test();
conv2d_padding_stride_test();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment