"include/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "095c2a4948b8b28f3fc65108629fa04cc45daa33"
Commit 1e02e941 authored by Paul's avatar Paul
Browse files

Adjust block_size for navi

parent 308db690
......@@ -164,9 +164,9 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
};
}
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
std::size_t compute_block_size(context& ctx, std::size_t n, std::size_t max_block_size)
{
const std::size_t min_block_size = 64;
const std::size_t min_block_size = ctx.get_current_device().get_wavefront_size();
auto block_size = (((n - 1) / min_block_size + 1)) * min_block_size;
return std::min(std::max(min_block_size, block_size), max_block_size);
}
......
......@@ -72,7 +72,7 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over = 1);
MIGRAPHX_GPU_EXPORT operation compile_hip_code_object(const std::string& content,
hip_compile_options options);
MIGRAPHX_GPU_EXPORT std::size_t compute_block_size(std::size_t n,
MIGRAPHX_GPU_EXPORT std::size_t compute_block_size(context& ctx, std::size_t n,
std::size_t max_block_size = 1024);
MIGRAPHX_GPU_EXPORT std::string generate_make_shape(const shape& s);
......
......@@ -166,7 +166,7 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
if(algo == "block")
{
auto block_size = compute_block_size(relements, 256);
auto block_size = compute_block_size(ctx, relements, 256);
if(relements >= block_size * 256)
algo = "block_large";
options.set_launch_params(
......@@ -274,7 +274,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto relements = reduction_shape.elements() / vec.size;
if(algo == "block")
{
auto block_size = compute_block_size(relements, 256);
auto block_size = compute_block_size(ctx, relements, 256);
if(relements >= block_size * 256)
algo = "block_large";
options.set_launch_params(
......
......@@ -75,7 +75,7 @@ struct softmax_compiler : compiler<softmax_compiler>
}
auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
auto block_size = compute_block_size(relements, 256);
auto block_size = compute_block_size(ctx, relements, 256);
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment