#pragma once #include "philox.cuh" #include "utils.h" using namespace flash; template inline __device__ void apply_mask(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int max_seqlen_k, const int col_idx_offset_ = 0) { const int lane_id = threadIdx.x & 63; const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 2; #pragma unroll for (int ni = 0; ni < (WARP_N / 32); ++ni) { #pragma unroll for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n; #pragma unroll for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { const int col_idx = col_idx_base + vec_idx * 8; if (col_idx >= max_seqlen_k) { #pragma unroll for (int mi = 0; mi < (WARP_M / 32); ++mi) { #pragma unroll for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY; } } } } } } } template inline __device__ void apply_dropout(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int max_seqlen_k, const int col_idx_offset_, unsigned long long seed, unsigned long long offset, uint32_t p_dropout_in_8bits_value, union_vec2_uint rowcol, uint32_t* dropout_debug_count) { const int lane_id = threadIdx.x & 63; // lane id, 0-63 const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 2; // prepare 4 uint for 16 uint8 union_vec4_uint random_uint4; for (int mi = 0; mi < (WARP_M / 32); ++mi, rowcol.u32.x += BLOCK_ROW_STRIDE) { #pragma unroll for (uint32_t ni = 0; ni < (WARP_N / 32); ++ni, ++rowcol.u32.y) { // for each 16 elements, generate 16 int8 -> 4 u32 random_uint4.u32 = flash::philox(seed, rowcol.u64, offset); #pragma unroll for(uint32_t min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n; #pragma unroll for(uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) { if constexpr (Is_even_MN) { #pragma unroll for(uint32_t min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { uint32_t cur_pos = (min_tile_n * 2 + min_tile_m) * 4 + vec_idx; uint32_t cur_rand = random_uint4.u8[cur_pos] & 0xffffffff; // uint8 -> u32, since hcu has no compare instructions with 8/16 bits if (cur_rand > p_dropout_in_8bits_value) { tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = 0x0; #ifdef FA_DEBUG atomicAdd(dropout_debug_count, 1); #endif } } } else if constexpr (not Is_even_MN) { const int col_idx = col_idx_base + vec_idx * 8; if (col_idx < max_seqlen_k) { #pragma unroll for(uint32_t min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { uint32_t cur_pos = (min_tile_n * 2 + min_tile_m) * 4 + vec_idx; uint32_t cur_rand = random_uint4.u8[cur_pos] & 0xffffffff; // uint8 -> u32, since hcu has no compare instructions with 8/16 bits if (cur_rand > p_dropout_in_8bits_value) { tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = 0x0; #ifdef FA_DEBUG atomicAdd(dropout_debug_count, 1); #endif } } } } } } } } } template inline __device__ void apply_mask_causal(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_, const int max_seqlen_q) { const int lane_id = threadIdx.x & 63; const int row_idx_offset = row_idx_offset_ + (lane_id & 15) * 2; const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 2; #pragma unroll for (int mi = 0; mi < (WARP_M / 32); ++mi) { const int row_idx_base = row_idx_offset + mi * 32; #pragma unroll for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { const int row_idx = row_idx_base + min_tile_m; const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q); // attention, when max_seqlen_k == max_seqlen_q, vgpr can be reduced again #pragma unroll for (int ni = 0; ni < (WARP_N / 32); ++ni) { #pragma unroll for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n; #pragma unroll for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { const int col_idx = col_idx_base + vec_idx * 8; tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = (col_idx > col_idx_limit_right) ? -INFINITY: tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]; } } } } } } template inline __device__ void apply_mask_local(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_, const int max_seqlen_q, const int window_size_left, const int window_size_right) { const int lane_id = threadIdx.x & 63; const int row_idx_offset = row_idx_offset_ + (lane_id & 15) * 2; const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 2; #pragma unroll for (int mi = 0; mi < (WARP_M / 32); ++mi) { const int row_idx_base = row_idx_offset + mi * 32; #pragma unroll for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { const int row_idx = row_idx_base + min_tile_m; const int col_idx_limit_left = std::max(0, row_idx + 1 + max_seqlen_k - max_seqlen_q - window_size_left); const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll for (int ni = 0; ni < (WARP_N / 32); ++ni) { #pragma unroll for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n; #pragma unroll for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { const int col_idx = col_idx_base + vec_idx * 8; tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] = (col_idx > col_idx_limit_right || (HasWSLeft && col_idx < (col_idx_limit_left - 1))) ? -INFINITY: tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]; } } } } } } template inline __device__ void apply_alibi(DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4], const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_, const int max_seqlen_q, float gAlibi) { const int lane_id = threadIdx.x & 63; const int row_idx_offset = row_idx_offset_ + (lane_id & 15) * 2; const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 2; #pragma unroll for (int mi = 0; mi < (WARP_M / 32); ++mi) { const int row_idx_base = row_idx_offset + mi * 32; #pragma unroll for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { const int row_idx = row_idx_base + min_tile_m; #pragma unroll for (int ni = 0; ni < (WARP_N / 32); ++ni) { #pragma unroll for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n; #pragma unroll for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { const int col_idx = col_idx_base + vec_idx * 8; tensor[mi + ni * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx] += gAlibi * (col_idx - row_idx); } } } } } } template __device__ inline void thread_reduce_max(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) { if(zero_init == true) { #pragma unroll for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) { #pragma unroll for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { summary[m_idx * 2].f32[min_tile_m] = -INFINITY; // OpType:0 is sum operator, 1 is max operator #pragma unroll for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) { #pragma unroll for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread summary[m_idx * 2].f32[min_tile_m] = op(summary[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]); } } } } } } else { #pragma unroll for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) { #pragma unroll for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { summary_cur[m_idx * 2].f32[min_tile_m] = summary[m_idx * 2].f32[min_tile_m]; #pragma unroll for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) { #pragma unroll for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread summary_cur[m_idx * 2].f32[min_tile_m] = op(summary_cur[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]); } } } } } } } template __device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) { if(zero_init == true) { #pragma unroll for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) { // 对于 gfx936 及以上的架构, 可以使用 v_pk_add_f32 #if defined(__gfx936__) || defined(__gfx938__) summary[m_idx * 2].u64 = 0x0; #pragma unroll for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) { #pragma unroll for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { __float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]}; summary[m_idx * 2].u64 = __builtin_hcu_pk_add_f32( summary[m_idx * 2].u64, additem_pair ); } } } #else #pragma unroll for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { summary[m_idx * 2].f32[min_tile_m] = 0; // OpType:0 is sum operator, 1 is max operator #pragma unroll for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) { #pragma unroll for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread summary[m_idx * 2].f32[min_tile_m] = op(summary[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]); } } } } #endif } } else { #pragma unroll for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) { #if defined(__gfx936__) || defined(__gfx938__) summary_cur[m_idx * 2].u64 = summary[m_idx * 2].u64; #pragma unroll for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) { #pragma unroll for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread __float2 additem_pair = {tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2].f32[vec_idx], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + 1].f32[vec_idx]}; summary_cur[m_idx * 2].u64 = __builtin_hcu_pk_add_f32( summary_cur[m_idx * 2].u64, additem_pair ); } } } #else #pragma unroll for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { summary_cur[m_idx * 2].f32[min_tile_m] = summary[m_idx * 2].f32[min_tile_m]; #pragma unroll for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) { #pragma unroll for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { // mmac min_tile is 16*16, a warp is 64 thread summary_cur[m_idx * 2].f32[min_tile_m] = op(summary_cur[m_idx * 2].f32[min_tile_m], tensor[m_idx + n_idx * (WARP_M / 32)][min_tile_n * 2 + min_tile_m].f32[vec_idx]); } } } } #endif } } } template __device__ inline void quad_allreduce_(DataType *dst, DataType *src, Operator &op) { #pragma unroll for (int mi = 0; mi < (WARP_M / 32); mi++) { dst[mi] = Allreduce<64>::run(src[mi], op); } } template __device__ inline void reduce_(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *summary, Operator &op, DataType1 *summary_cur=nullptr) { if constexpr (OpType == 0) { // sum if constexpr (zero_init == true) { thread_reduce_sum(tensor, summary, op); quad_allreduce_(summary, summary, op); } else { thread_reduce_sum(tensor, summary, op, summary_cur); quad_allreduce_(summary_cur, summary_cur, op); } } else if constexpr (OpType == 1) { // max if constexpr (zero_init == true) { thread_reduce_max(tensor, summary, op); quad_allreduce_(summary, summary, op); } else { thread_reduce_max(tensor, summary, op, summary_cur); quad_allreduce_(summary_cur, summary_cur, op); } } } // zero_init==true, max is current max_score, max_cur=nullptr // zero_init==true, max is prev max_score, max_cur!=nullptr template __device__ inline void reduce_max(const DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *max , DataType1 *max_cur=nullptr) { MaxOp max_op; if constexpr (zero_init == true) { reduce_, 1, DataType0, DataType1, WARP_M, WARP_N>(tensor, max, max_op); } else { reduce_, 1, DataType0, DataType1, WARP_M, WARP_N>(tensor, max, max_op, max_cur); } } template __device__ inline void reduce_sum(DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], DataType1 *sum, DataType1 *sum_cur=nullptr){ SumOp sum_op; if constexpr (zero_init == true) { reduce_, 0, DataType0, DataType1, WARP_M, WARP_N>(tensor, sum, sum_op); } else { reduce_, 0, DataType0, DataType1, WARP_M, WARP_N>(tensor, sum, sum_op, sum_cur); } } // Apply the exp to all the elements. template inline __device__ void scale_apply_exp2(DataType0 tensor[(WARP_M / 32) * (WARP_N / 32)][4], const DataType1 *max, const float scale) { #pragma unroll for (int mi = 0; mi < (WARP_M / 32); ++mi) { // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. // If we don't have float around M_LOG2E the multiplication is done in fp64. #pragma unroll for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { const float max_scaled = (max[mi * 2].f32[min_tile_m] == -INFINITY) ? 0.f : (max[mi * 2].f32[min_tile_m] * (Scale_max ? scale : float(M_LOG2E))); __float2 neg_max_scaled_pair = {-max_scaled, -max_scaled}; __float2 scale_pair = {scale, scale}; #pragma unroll for (int ni = 0; ni < (WARP_N / 32); ++ni) { // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - // max * log_2(e)) This allows the compiler to use the ffma // instruction instead of fadd and fmul separately. // min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16 for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { int mmac_id = min_tile_n * 2 + min_tile_m; int qk_tile_id = mi + ni * (WARP_M / 32); #if defined(__gfx936__) || defined(__gfx938__) for(int vec_idx = 0; vec_idx < 2; ++vec_idx) { tensor[qk_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_fma_f32( tensor[qk_tile_id][mmac_id].u64[vec_idx], scale_pair, neg_max_scaled_pair ); } for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx]); } #else for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx] * scale - max_scaled); } #endif } } } } } template inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_M / 32)][4], DataType1 *scores_max, DataType1 *scores_sum, DataType0 acc_o[(K / kBlockK) * (WARP_M / 32) * (kBlockK / 32)][4], float softmax_scale_log2) { if constexpr (Is_first) { reduce_max(scores, scores_max); scale_apply_exp2(scores, scores_max, softmax_scale_log2); reduce_sum(scores, scores_sum); } else { DataType1 scores_max_cur[(WARP_M / 32)]; reduce_max(scores, scores_max, scores_max_cur); // scores_max is prev scores max for (int mi = 0; mi < (WARP_M / 32); ++mi) { // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. // If we don't have float around M_LOG2E the multiplication is done in fp64. #pragma unroll for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { float scores_max_cur_reg = !Check_inf ? scores_max_cur[mi * 2].f32[min_tile_m] : (scores_max_cur[mi * 2].f32[min_tile_m] == -INFINITY ? 0.0f : scores_max_cur[mi * 2].f32[min_tile_m]); float scores_scale = __llvm_exp2_f32((scores_max[mi * 2].f32[min_tile_m] - scores_max_cur_reg) * softmax_scale_log2); scores_sum[mi * 2].f32[min_tile_m] *= scores_scale; __float2 scores_scale_pair = {scores_scale, scores_scale}; #pragma unroll for(int pv_n_loop = 0; pv_n_loop < (K / kBlockK); pv_n_loop++) { #pragma unroll for (int ni = 0; ni < (kBlockK / 32); ++ni) { // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - // max * log_2(e)) This allows the compiler to use the ffma // instruction instead of fadd and fmul separately. // min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16 for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) { int pv_tile_id = pv_n_loop * (WARP_M / 32) * (kBlockK / 32) + mi + ni * (WARP_M / 32); int mmac_id = min_tile_n * 2 + min_tile_m; #if defined(__gfx936__) || defined(__gfx938__) #pragma unroll for(int vec_idx = 0; vec_idx < 2; ++vec_idx) { acc_o[pv_tile_id][min_tile_n * 2 + min_tile_m].u64[vec_idx] = __builtin_hcu_pk_mul_f32( acc_o[pv_tile_id][mmac_id].u64[vec_idx], scores_scale_pair ); } #else #pragma unroll for(int vec_idx = 0; vec_idx < 4; ++vec_idx) { acc_o[pv_tile_id][mmac_id].f32[vec_idx] *= scores_scale; } #endif } } } } } scale_apply_exp2(scores, scores_max_cur, softmax_scale_log2); DataType1 scores_sum_cur[(WARP_M / 32)]; for (int mi = 0; mi < (WARP_M / 32); ++mi) { scores_sum_cur[mi].u64 = 0x0; } reduce_sum(scores, scores_sum_cur); for (int mi = 0; mi < (WARP_M / 32); ++mi) { #if defined(__gfx936__) || defined(__gfx938__) scores_sum[mi].u64 = __builtin_hcu_pk_add_f32( scores_sum[mi].u64, scores_sum_cur[mi].u64 ); #else // for perf-model, add listed below will be optimized as v_fmac_f32, leading to incorrect results scores_sum[mi].f32[0] += scores_sum_cur[mi].f32[0]; scores_sum[mi].f32[1] += scores_sum_cur[mi].f32[1]; #endif #if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__)) inlineasm_fa_v_mov_b64( scores_max[mi].u64, scores_max_cur[mi].u64 ); #else scores_max[mi].f32[0] = scores_max_cur[mi].f32[0]; scores_max[mi].f32[1] = scores_max_cur[mi].f32[1]; #endif } } }; template inline __device__ void convert_pk_type(union_vec2_f16x2 p_reg[(WARP_M / 32) * (WARP_N / 32)][4], union_vec4_fp32 s_reg[(WARP_M / 32) * (WARP_N / 32)][4]) { #pragma unroll for(int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) { #pragma unroll for(int m_idx = 0; m_idx < (WARP_M / 32); ++m_idx) { #pragma unroll for(int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) { #pragma unroll for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) { #if defined(__gfx938__) p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair( s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32x2[min_tile_k]); p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16x2[min_tile_k] = DownCastPair( s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32x2[min_tile_k]); #else p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast( s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]); p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 0] = DownCast( s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 0]); p_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f16[min_tile_k * 2 + 1] = DownCast( s_reg[n_idx * (WARP_M / 32) + m_idx][0 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]); p_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f16[min_tile_k * 2 + 1] = DownCast( s_reg[n_idx * (WARP_M / 32) + m_idx][1 * 2 + min_tile_m].f32[min_tile_k * 2 + 1]); #endif } } } } }