Commit 94256bc4 authored by Paul's avatar Paul
Browse files

Update block size calculation

parent 0ee486c5
...@@ -144,10 +144,9 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over) ...@@ -144,10 +144,9 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size) std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{ {
size_t block_size = 128; const std::size_t min_block_size = 64;
while(block_size <= max_block_size and block_size <= n) auto block_size = (((n - 1) / min_block_size + 1)) * min_block_size;
block_size *= 2; return std::min(block_size, max_block_size);
return block_size / 2;
} }
operation compile_hip_code_object(const std::string& content, hip_compile_options options) operation compile_hip_code_object(const std::string& content, hip_compile_options options)
......
...@@ -62,7 +62,7 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -62,7 +62,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
auto preloads = preload::broadcasts(axis, inputs); auto preloads = preload::broadcasts(axis, inputs);
auto relements = inputs[0].lens()[axis] / vec.size; auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = inputs.back().elements() / relements; auto nelements = inputs.back().elements() / relements;
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 512);
hip_compile_options options; hip_compile_options options;
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size); v, compute_global_for(ctx, nelements * block_size, 256), block_size);
......
...@@ -58,7 +58,7 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -58,7 +58,7 @@ struct softmax_compiler : compiler<softmax_compiler>
} }
auto relements = inputs[0].lens()[axis] / vec.size; auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = inputs.back().elements() / relements; auto nelements = inputs.back().elements() / relements;
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 512);
hip_compile_options options; hip_compile_options options;
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size); v, compute_global_for(ctx, nelements * block_size, 256), block_size);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment