Commit 6ea64443 authored by Astha Rai's avatar Astha Rai
Browse files

fixed descriptor and isSupportedArgument stride problem

parent fd87d533
...@@ -98,7 +98,6 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -98,7 +98,6 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
desc_mnk, desc_mnk,
make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n), make_right_pad_transform(k, pad_k)), make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n), make_right_pad_transform(k, pad_k)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return desc_mnk_pad; return desc_mnk_pad;
} }
...@@ -225,8 +224,8 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -225,8 +224,8 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
{ {
index_t gridSize = getAvailableComputeUnitCount(stream_config); index_t gridSize = getAvailableComputeUnitCount(stream_config);
index_t num_threads_m = (gridSize * arg.blockSize_) / 16; index_t num_threads_m = (gridSize * arg.blockSize_) / 16;
index_t num_threads_n = 4; index_t num_threads_n = 16;
index_t num_threads_k = 4; index_t num_threads_k = 16;
auto in_grid_3d_desc_tuple = generate_tuple( auto in_grid_3d_desc_tuple = generate_tuple(
[&](auto I) { [&](auto I) {
...@@ -235,7 +234,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -235,7 +234,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
gridSize, gridSize,
arg.blockSize_, arg.blockSize_,
num_threads_m, num_threads_m,
num_threads_n,); num_threads_n,
num_threads_k); num_threads_k);
}, },
Number<NumInput>{}); Number<NumInput>{});
...@@ -247,7 +246,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -247,7 +246,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
gridSize, gridSize,
arg.blockSize_, arg.blockSize_,
num_threads_m, num_threads_m,
num_threads_n,); num_threads_n,
num_threads_k); num_threads_k);
}, },
Number<NumOutput>{}); Number<NumOutput>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment