Commit 5515c9a5 authored by Paul
Browse files

Fix wrong global size

parent b4c4234d
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_FINALIZE)
struct module_impl struct module_impl
{ {
// A list is used to keep references to an instruction stable // A list is used to keep references to an instruction stable
...@@ -553,8 +555,14 @@ instruction_ref module::find_dangling_reference() const ...@@ -553,8 +555,14 @@ instruction_ref module::find_dangling_reference() const
void module::finalize(context& ctx) void module::finalize(context& ctx)
{ {
const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
if (trace)
{
std::cout << "Finalize: ";
this->debug_print(ins);
}
ins->finalize(ctx); ins->finalize(ctx);
for(const auto& smod : ins->module_inputs()) for(const auto& smod : ins->module_inputs())
{ {
......
...@@ -30,7 +30,7 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs ...@@ -30,7 +30,7 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
auto len = input.lens()[axis]; auto len = input.lens()[axis];
if(stride != 0 and stride != 1) if(stride != 0 and stride != 1)
return 1; return 1;
if(len == 1) if(len == 1 and input.elements() > sizes.front())
return sizes.front(); return sizes.front();
auto it = std::find_if( auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; }); sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
......
...@@ -129,6 +129,10 @@ std::size_t compute_block_size(std::size_t n, std::size_t max_block_size) ...@@ -129,6 +129,10 @@ std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
operation compile_hip_code_object(const std::string& content, hip_compile_options options) operation compile_hip_code_object(const std::string& content, hip_compile_options options)
{ {
assert(options.global > 0);
assert(options.local > 0);
assert(options.inputs.size() > 0);
assert(options.inputs.size() == options.virtual_inputs.size() or options.virtual_inputs.empty());
std::vector<src_file> srcs; std::vector<src_file> srcs;
std::transform(migraphx_kernels().begin(), std::transform(migraphx_kernels().begin(),
migraphx_kernels().end(), migraphx_kernels().end(),
......
...@@ -59,6 +59,8 @@ void launch_kernel(hipFunction_t fun, ...@@ -59,6 +59,8 @@ void launch_kernel(hipFunction_t fun,
void* kernargs, void* kernargs,
std::size_t size) std::size_t size)
{ {
assert(global > 0);
assert(local > 0);
void* config[] = { void* config[] = {
// HIP_LAUNCH_PARAM_* are macros that do horrible things // HIP_LAUNCH_PARAM_* are macros that do horrible things
#ifdef MIGRAPHX_USE_CLANG_TIDY #ifdef MIGRAPHX_USE_CLANG_TIDY
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment