Commit e8858300 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed a bug with block size

parent 40a5f496
...@@ -25,7 +25,7 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg) ...@@ -25,7 +25,7 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
auto* in_ptr = device_cast(input.data()); auto* in_ptr = device_cast(input.data());
visit_tensor_size(out_lens.size(), [&](auto out_dim) { visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(comp_shape); hip_tensor_descriptor<out_dim> desc(comp_shape);
gs_launch(stream, nelements)([=](auto ii) { gs_launch(stream, nelements, 256)([=](auto ii) {
const size_t nb = 4; const size_t nb = 4;
auto idx = desc.multi(ii); auto idx = desc.multi(ii);
std::size_t i_m = idx[dim_1]; std::size_t i_m = idx[dim_1];
...@@ -56,7 +56,7 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg) ...@@ -56,7 +56,7 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg)
auto* in_ptr = device_cast(input.data()); auto* in_ptr = device_cast(input.data());
visit_tensor_size(out_lens.size(), [&](auto out_dim) { visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(comp_shape); hip_tensor_descriptor<out_dim> desc(comp_shape);
gs_launch(stream, nelements)([=](auto ii) { gs_launch(stream, nelements, 256)([=](auto ii) {
const size_t nb = 4; const size_t nb = 4;
auto idx = desc.multi(ii); auto idx = desc.multi(ii);
std::size_t i_n = idx[dim_1]; std::size_t i_n = idx[dim_1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment