Commit a18e6481 authored by Qianfeng Zhang
Browse files

Tiny fix to the data type template parameters used in the blockwise and direct_threadwise kernels

parent 9e80cdce
...@@ -281,7 +281,7 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -281,7 +281,7 @@ struct GridwiseReduction_xy_to_x_blockwise
ThreadClusterLengths, ThreadClusterLengths,
Sequence<0, 1>, Sequence<0, 1>,
srcDataType, srcDataType,
dstDataType, compType,
src2dDescType, src2dDescType,
decltype(in_block_desc), decltype(in_block_desc),
Sequence<0, 1>, Sequence<0, 1>,
......
...@@ -232,7 +232,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -232,7 +232,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType, auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
dstDataType, compType,
src2dDescType, src2dDescType,
decltype(ThreadBufferDesc), decltype(ThreadBufferDesc),
ThreadBufferLengths, ThreadBufferLengths,
...@@ -377,7 +377,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -377,7 +377,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<srcDataType, auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
dstDataType, compType,
src2dDescType, src2dDescType,
decltype(ThreadBufferDesc), decltype(ThreadBufferDesc),
ThreadBufferLengths, ThreadBufferLengths,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment