"src/targets/gpu/gemm_impl.cpp" did not exist on "fc03c7060c4f1ea43f199248001899c4d0fd4d71"
Commit c0d3babb authored by Shucai Xiao
Browse files

Fixed a bug in calculating indices.

parent 5fe89b69
...@@ -61,23 +61,19 @@ struct hip_tensor_descriptor ...@@ -61,23 +61,19 @@ struct hip_tensor_descriptor
{ {
std::copy(s.lens().begin(), s.lens().end(), lens); std::copy(s.lens().begin(), s.lens().end(), lens);
std::copy(s.strides().begin(), s.strides().end(), strides); std::copy(s.strides().begin(), s.strides().end(), strides);
std::vector<std::size_t> vec_idx(s.lens().size());
std::iota(vec_idx.begin(), vec_idx.end(), 0);
std::sort(vec_idx.begin(), vec_idx.end(), [&](size_t i, size_t j) {
return strides[i] > strides[j];
});
std::copy(vec_idx.begin(), vec_idx.end(), indices);
} }
__device__ __host__ hip_index<NDim> multi(size_t idx) const __device__ __host__ hip_index<NDim> multi(size_t idx) const
{ {
hip_index<NDim> result{}; hip_index<NDim> result{};
size_t tidx = idx; size_t tidx = idx;
for(size_t is = 0; is < NDim; is++) for(size_t is = 0; is < NDim; is++)
{ {
result[indices[is]] = tidx / strides[indices[is]]; result[is] = tidx / strides[is];
tidx = tidx % strides[indices[is]]; tidx = tidx % strides[is];
} }
return result; return result;
} }
...@@ -90,7 +86,6 @@ struct hip_tensor_descriptor ...@@ -90,7 +86,6 @@ struct hip_tensor_descriptor
} }
size_t lens[NDim] = {}; size_t lens[NDim] = {};
size_t strides[NDim] = {}; size_t strides[NDim] = {};
size_t indices[NDim] = {};
}; };
} // namespace device } // namespace device
......
...@@ -40,11 +40,15 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg) ...@@ -40,11 +40,15 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
void pack_b(hipStream_t stream, const argument& result, const argument& arg) void pack_b(hipStream_t stream, const argument& result, const argument& arg)
{ {
auto output_shape = result.get_shape(); auto trans_shape = result.get_shape();
auto out_lens = output_shape.lens(); auto out_lens = trans_shape.lens();
auto dim_0 = output_shape.lens().size() - 2; auto dim_0 = trans_shape.lens().size() - 2;
auto dim_1 = output_shape.lens().size() - 1; auto dim_1 = trans_shape.lens().size() - 1;
std::size_t ldb = output_shape.strides()[dim_1]; std::size_t ldb = trans_shape.strides()[dim_1];
auto wrap_lens = out_lens;
std::swap(wrap_lens[dim_0], wrap_lens[dim_1]);
shape output_shape{trans_shape.type(), wrap_lens};
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1]; std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) { visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
...@@ -55,8 +59,8 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg) ...@@ -55,8 +59,8 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg)
gs_launch(stream, nelements)([=](auto ii) { gs_launch(stream, nelements)([=](auto ii) {
const size_t nb = 4; const size_t nb = 4;
auto idx = desc.multi(ii); auto idx = desc.multi(ii);
std::size_t i_n = idx[dim_0]; std::size_t i_n = idx[dim_1];
std::size_t i_k = idx[dim_1]; std::size_t i_k = idx[dim_0];
std::size_t offset = ii / m_size * m_size; std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] = out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] =
in_ptr[i_n + i_k * ldb + offset]; in_ptr[i_n + i_k * ldb + offset];
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment