Unverified Commit 3c301efa authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

fix a bug in create tensor_view with vec data type (#1155)

When create a tensor_view with vector date type, the last dimension of the shape should be divided by the vec_size.
parent 1e0bbd78
...@@ -176,7 +176,13 @@ template <index_int N, class T, class... Ts> ...@@ -176,7 +176,13 @@ template <index_int N, class T, class... Ts>
auto hip_vec_visit_all(T&& x, Ts&&... xs) auto hip_vec_visit_all(T&& x, Ts&&... xs)
{ {
return [&](auto f) { return [&](auto f) {
hip_visit_all_impl(get_shape(x), auto sx = get_shape(x);
auto lens = sx.lens();
assert(lens.back() % N == 0);
assert(sx.strides().back() == 1);
lens.back() /= N;
shape vec_sx{sx.type(), lens};
hip_visit_all_impl(vec_sx,
make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }), make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }),
f, f,
x, x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment