Unverified Commit f8ed456e authored by Li Zhang, committed by GitHub
Browse files

[Fix] Implement movmatrix using warp shuffling for CUDA < 11.8 (#267)

parent 903707b5
...@@ -26,7 +26,26 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha ...@@ -26,7 +26,26 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha
#endif #endif
} }
/**
 * Warp-shuffle implementation of the 8x8 b16 tile transpose.
 *
 * Every lane contributes one 32-bit word that packs two 16-bit elements.
 * The lane fetches two such words from other lanes of the warp and builds
 * its result from one 16-bit half of each:
 *   - source lanes are `lane_id / 8 + lane_id % 4 * 8` and that lane + 4;
 *   - lanes with `lane_id % 8 < 4` take the low halves, all other lanes
 *     take the high halves;
 *   - the word from the first source lane always supplies the low half of
 *     the result, the word from the second source lane the high half.
 *
 * Both shuffles use the full-warp mask 0xffffffff, so all 32 lanes of the
 * warp must call this function together.
 *
 * @param value    packed pair of 16-bit elements owned by the calling lane
 * @param lane_id  presumably the caller's lane index within its warp
 *                 (0..31) -- TODO confirm against the call site
 * @return packed pair of 16-bit elements owned by this lane after the
 *         transpose
 */
__inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_id)
{
    // Lane that owns the element ending up in the low half of the result;
    // the element for the high half lives four lanes further on.
    const int first_src = lane_id / 8 + lane_id % 4 * 8;

    const uint word_lo = __shfl_sync(0xffffffff, value, first_src);
    const uint word_hi = __shfl_sync(0xffffffff, value, first_src + 4);

    // Pack with masks and shifts instead of reinterpreting through short2.
    if (lane_id % 8 < 4) {
        // low 16 bits of each fetched word
        return (word_lo & 0x0000ffffu) | (word_hi << 16);
    }
    // high 16 bits of each fetched word
    return (word_lo >> 16) | (word_hi & 0xffff0000u);
}
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
__inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)
{ {
#if TURBOMIND_ARCH_SM75 #if TURBOMIND_ARCH_SM75
uint d; uint d;
...@@ -37,6 +56,18 @@ __inline__ __device__ uint transpose_m8n8_b16(uint a) ...@@ -37,6 +56,18 @@ __inline__ __device__ uint transpose_m8n8_b16(uint a)
return 0; return 0;
#endif #endif
} }
#endif
/**
 * Transposes an 8x8 tile of 16-bit elements held across a warp, choosing
 * the implementation at compile time from the CUDA compiler version.
 *
 * When the version test below fails, the warp-shuffle implementation is
 * used and `lane_id` is forwarded to it; otherwise the movmatrix-based
 * implementation is used and `lane_id` is ignored.
 *
 * NOTE(review): the test requires the *minor* version to be >= 8 for every
 * major version, so e.g. 12.0 also takes the shuffle path. That may not be
 * intended, but the movmatrix definition is compiled under the same test,
 * so both guards have to be changed together.
 *
 * @param a        packed pair of 16-bit elements owned by the calling lane
 * @param lane_id  passed through to the shuffle path; unused otherwise
 * @return packed pair of 16-bit elements owned by this lane after the
 *         transpose
 */
__inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
{
#if !((__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8))
    // Fallback path: emulate the transpose with warp shuffles.
    return transpose_m8n8_b16_warp_shuffle(a, lane_id);
#else
    // movmatrix path does not need the lane index.
    static_cast<void>(lane_id);
    return transpose_m8n8_b16_movmatrix(a);
#endif
}
namespace ops { namespace ops {
...@@ -246,7 +277,7 @@ struct Gemm { ...@@ -246,7 +277,7 @@ struct Gemm {
// convert to half // convert to half
half2 half_C = __float22half2_rn(frag_C[j * 2 + x]); half2 half_C = __float22half2_rn(frag_C[j * 2 + x]);
// transpose 8x8 accum tile // transpose 8x8 accum tile
uint trans_C = transpose_m8n8_b16((uint&)half_C); uint trans_C = transpose_m8n8_b16((uint&)half_C, lane_id);
// store to global memory // store to global memory
OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n); OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment