Commit 735e102a authored by Paul
Browse files

Formatting

parent e63c09c5
...@@ -98,7 +98,7 @@ struct check_shapes ...@@ -98,7 +98,7 @@ struct check_shapes
const check_shapes& not_broadcasted() const const check_shapes& not_broadcasted() const
{ {
// if(!this->all_of([](const shape& s) { return not s.broadcasted(); })) // if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
// MIGRAPH_THROW(prefix() + "Shapes are broadcasted"); // MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
return *this; return *this;
} }
......
...@@ -12,7 +12,7 @@ struct index ...@@ -12,7 +12,7 @@ struct index
std::size_t group; std::size_t group;
}; };
template<class F> template <class F>
__global__ void launcher(F f) __global__ void launcher(F f)
{ {
index idx{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; index idx{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x};
...@@ -27,12 +27,7 @@ auto launch(std::size_t global, std::size_t local) ...@@ -27,12 +27,7 @@ auto launch(std::size_t global, std::size_t local)
using f_type = decltype(f); using f_type = decltype(f);
dim3 nblocks(global / local); dim3 nblocks(global / local);
dim3 nthreads(local); dim3 nthreads(local);
hipLaunchKernelGGL((launcher<f_type>), hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, nullptr, f);
nblocks,
nthreads,
0,
nullptr,
f);
}; };
} }
...@@ -135,17 +130,17 @@ void hip_contiguous(migraph::shape output_shape, migraph::argument arg, migraph: ...@@ -135,17 +130,17 @@ void hip_contiguous(migraph::shape output_shape, migraph::argument arg, migraph:
const auto& s = arg.get_shape(); const auto& s = arg.get_shape();
hip_tensor_descriptor<ndim> a_desc(s.lens(), s.strides()); hip_tensor_descriptor<ndim> a_desc(s.lens(), s.strides());
hip_tensor_descriptor<ndim> at_desc(output_shape.lens(), output_shape.strides()); hip_tensor_descriptor<ndim> at_desc(output_shape.lens(), output_shape.strides());
auto* a = input.data(); auto* a = input.data();
auto* at = output.data(); auto* at = output.data();
auto nelements = s.elements(); auto nelements = s.elements();
std::size_t nlocal = 512; std::size_t nlocal = 512;
std::size_t nglobal = 512*nlocal; std::size_t nglobal = 512 * nlocal;
launch(nglobal, nlocal)([=](auto idx) mutable { launch(nglobal, nlocal)([=](auto idx) mutable {
for(size_t i = idx.global; i < nelements; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
size_t lidx = a_desc.linear(at_desc.multi(i)); size_t lidx = a_desc.linear(at_desc.multi(i));
at[i] = a[lidx]; at[i] = a[lidx];
} }
}); });
}); });
......
...@@ -17,17 +17,12 @@ ...@@ -17,17 +17,12 @@
struct auto_eval struct auto_eval
{ {
migraph::program* p; migraph::program* p;
migraph::program::parameter_map * m; migraph::program::parameter_map* m;
migraph::argument result; migraph::argument result;
auto_eval(migraph::program& pp, migraph::program::parameter_map& pm) auto_eval(migraph::program& pp, migraph::program::parameter_map& pm) : p(&pp), m(&pm) {}
: p(&pp), m(&pm)
{}
migraph::argument operator()() const migraph::argument operator()() const { return p->eval(*m); }
{
return p->eval(*m);
}
~auto_eval() ~auto_eval()
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment