Commit c0d3babb authored by Shucai Xiao
Browse files

Fixed a bug in calculating indices.

parent 5fe89b69
......@@ -61,23 +61,19 @@ struct hip_tensor_descriptor
{
std::copy(s.lens().begin(), s.lens().end(), lens);
std::copy(s.strides().begin(), s.strides().end(), strides);
std::vector<std::size_t> vec_idx(s.lens().size());
std::iota(vec_idx.begin(), vec_idx.end(), 0);
std::sort(vec_idx.begin(), vec_idx.end(), [&](size_t i, size_t j) {
return strides[i] > strides[j];
});
std::copy(vec_idx.begin(), vec_idx.end(), indices);
}
// Splits the linear offset `idx` into one index per dimension by repeatedly
// dividing by a stride and carrying the remainder into the next step.
// Declared `__device__ __host__`; returns a `hip_index<NDim>`.
//
// NOTE(review): this text appears to be a rendered diff without +/- markers,
// so the loop body below seems to show two alternative versions of the same
// step (one going through `indices`, one using `is` directly) rather than
// four statements meant to run together -- confirm against the real file.
__device__ __host__ hip_index<NDim> multi(size_t idx) const
{
hip_index<NDim> result{};
// Remainder of the offset not yet assigned to a dimension.
size_t tidx = idx;
for(size_t is = 0; is < NDim; is++)
{
// Variant A: visit dimensions in the order stored in `indices`.
result[indices[is]] = tidx / strides[indices[is]];
tidx = tidx % strides[indices[is]];
// Variant B: visit dimensions in declaration order, 0 .. NDim-1.
result[is] = tidx / strides[is];
tidx = tidx % strides[is];
}
return result;
}
......@@ -90,7 +86,6 @@ struct hip_tensor_descriptor
}
size_t lens[NDim] = {};
size_t strides[NDim] = {};
size_t indices[NDim] = {};
};
} // namespace device
......
......@@ -40,11 +40,15 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
void pack_b(hipStream_t stream, const argument& result, const argument& arg)
{
auto output_shape = result.get_shape();
auto out_lens = output_shape.lens();
auto dim_0 = output_shape.lens().size() - 2;
auto dim_1 = output_shape.lens().size() - 1;
std::size_t ldb = output_shape.strides()[dim_1];
auto trans_shape = result.get_shape();
auto out_lens = trans_shape.lens();
auto dim_0 = trans_shape.lens().size() - 2;
auto dim_1 = trans_shape.lens().size() - 1;
std::size_t ldb = trans_shape.strides()[dim_1];
auto wrap_lens = out_lens;
std::swap(wrap_lens[dim_0], wrap_lens[dim_1]);
shape output_shape{trans_shape.type(), wrap_lens};
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
......@@ -55,8 +59,8 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg)
gs_launch(stream, nelements)([=](auto ii) {
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_n = idx[dim_0];
std::size_t i_k = idx[dim_1];
std::size_t i_n = idx[dim_1];
std::size_t i_k = idx[dim_0];
std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] =
in_ptr[i_n + i_k * ldb + offset];
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment