Commit 19d2b713 authored by wsttiger's avatar wsttiger
Browse files

Forgot to add these files to the previous commit

parent a36c7a1f
...@@ -163,16 +163,6 @@ struct pooling ...@@ -163,16 +163,6 @@ struct pooling
std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) / std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
static_cast<float>(stride[1]))) + static_cast<float>(stride[1]))) +
1)), 1)),
// std::size_t(std::max<std::ptrdiff_t>(
// 1,
// std::ptrdiff_t((input.lens()[2] + 2 * padding[0] - lengths[0]) /
// static_cast<float>(stride[0])) +
// 1)),
// std::size_t(std::max<std::ptrdiff_t>(
// 1,
// std::ptrdiff_t((input.lens()[3] + 2 * padding[1] - lengths[1]) /
// static_cast<float>(stride[1])) +
// 1)),
}}; }};
} }
......
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#include <migraph/gpu/hip.hpp> #include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp> #include <migraph/generate.hpp>
#include "softmax.h"
auto read_cifar10_images(const std::string& full_path) auto read_cifar10_images(const std::string& full_path)
{ {
std::ifstream file(full_path, std::ios::binary); std::ifstream file(full_path, std::ios::binary);
...@@ -42,16 +44,6 @@ auto read_cifar10_images(const std::string& full_path) ...@@ -42,16 +44,6 @@ auto read_cifar10_images(const std::string& full_path)
} }
} }
std::vector<float> softmax(std::vector<float> p)
{
size_t n = p.size();
std::vector<float> result(n);
std::transform(p.begin(), p.end(), result.begin(), [](auto x) { return std::exp(x); });
float s = std::accumulate(result.begin(), result.end(), 0.0f, std::plus<float>());
std::transform(result.begin(), result.end(), result.begin(), [=](auto x) { return x / s; });
return result;
}
int main(int argc, char const* argv[]) int main(int argc, char const* argv[])
{ {
if(argc < 4) if(argc < 4)
...@@ -85,7 +77,7 @@ int main(int argc, char const* argv[]) ...@@ -85,7 +77,7 @@ int main(int argc, char const* argv[])
auto result = migraph::gpu::from_gpu(prog.eval(m)); auto result = migraph::gpu::from_gpu(prog.eval(m));
std::vector<float> logits; std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax(logits); std::vector<float> probs = softmax<float>(logits);
for(auto x : probs) for(auto x : probs)
std::cout << x << " "; std::cout << x << " ";
std::cout << std::endl << std::endl; std::cout << std::endl << std::endl;
...@@ -106,7 +98,7 @@ int main(int argc, char const* argv[]) ...@@ -106,7 +98,7 @@ int main(int argc, char const* argv[])
auto result = prog.eval({{"0", input3}}); auto result = prog.eval({{"0", input3}});
std::vector<float> logits; std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax(logits); std::vector<float> probs = softmax<float>(logits);
for(auto x : probs) for(auto x : probs)
std::cout << x << " "; std::cout << x << " ";
std::cout << std::endl; std::cout << std::endl;
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include <migraph/gpu/hip.hpp> #include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp> #include <migraph/generate.hpp>
#include "softmax.h"
auto reverse_int(unsigned int i) auto reverse_int(unsigned int i)
{ {
unsigned char c1, c2, c3, c4; unsigned char c1, c2, c3, c4;
...@@ -98,16 +100,6 @@ std::vector<int32_t> read_mnist_labels(const std::string& full_path, int& number ...@@ -98,16 +100,6 @@ std::vector<int32_t> read_mnist_labels(const std::string& full_path, int& number
} }
} }
std::vector<float> softmax(std::vector<float> p)
{
size_t n = p.size();
std::vector<float> result(n);
std::transform(p.begin(), p.end(), result.begin(), [](auto x) { return std::exp(x); });
float s = std::accumulate(result.begin(), result.end(), 0.0f, std::plus<float>());
std::transform(result.begin(), result.end(), result.begin(), [=](auto x) { return x / s; });
return result;
}
int main(int argc, char const* argv[]) int main(int argc, char const* argv[])
{ {
if(argc > 3) if(argc > 3)
......
...@@ -659,7 +659,7 @@ int main() ...@@ -659,7 +659,7 @@ int main()
gemm_test<double>(); gemm_test<double>();
reshape_test(); reshape_test();
transpose_test(); transpose_test();
// contiguous_test(); contiguous_test();
softmax_test(); softmax_test();
// maxpool_test(); // maxpool_test();
conv2d_test(); conv2d_test();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment