Commit 3ce226ae authored by wenjh's avatar wenjh
Browse files

[DCU] Fix failed test cases



Due to the difference of warp size between nvidia(32) and dtk(64), the
OperatorTest/CTDBiasTestSuite.TestCTDBias/* are all failed except:

* OperatorTest/CTDBiasTestSuite.TestCTDBias/bfloat16Xfloat32X65536X128
* OperatorTest/CTDBiasTestSuite.TestCTDBias/bfloat16Xfloat16X65536X128
* OperatorTest/CTDBiasTestSuite.TestCTDBias/bfloat16Xbfloat16X65536X128
* OperatorTest/CTDBiasTestSuite.TestCTDBias/bfloat16Xfloat8e5m2X65536X128
* OperatorTest/CTDBiasTestSuite.TestCTDBias/bfloat16Xfloat8e4m3X65536X128

This commit is intended to fix this.
Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 46c81675
...@@ -171,7 +171,8 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], ...@@ -171,7 +171,8 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
for (unsigned int j = 0; j < nvec_in; ++j) { for (unsigned int j = 0; j < nvec_in; ++j) {
CType elt = step_dbias.data.elt[j]; CType elt = step_dbias.data.elt[j];
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
elt = __shfl(elt, dbias_shfl_src_lane); // shuffle data in a warp elt = __shfl(elt, dbias_shfl_src_lane, THREADS_PER_WARP); // shuffle data in a warp
__syncthreads();
#else #else
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
#endif #endif
......
...@@ -91,7 +91,8 @@ inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_O ...@@ -91,7 +91,8 @@ inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_O
for (unsigned int j = 0; j < NVEC_IN; ++j) { for (unsigned int j = 0; j < NVEC_IN; ++j) {
CType elt = step_dbias.data.elt[j]; CType elt = step_dbias.data.elt[j];
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
elt = __shfl(elt, dbias_shfl_src_lane); // shuffle data in a warp elt = __shfl(elt, dbias_shfl_src_lane, THREADS_PER_WARP); // shuffle data in a warp
__syncthreads();
#else #else
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment