Commit 8a69b46c authored by zhanghj2's avatar zhanghj2
Browse files

Merge branch 'feature/bmz-das-prefill-glm5-nheads64' into 'master'

优化 nmz和bmz dsa prefill,nhead=64

See merge request dcutoolkit/deeplearing/flashmla!6
parents a9ef79c6 14b2cfc5
......@@ -124,5 +124,35 @@ static void run(const SparseAttnFwdParams &params);
};
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048>
class KernelTemplate_B_H_64
{
public:
static constexpr int D_Q = D_QK;
static constexpr int D_K = D_QK;
static constexpr int D_V = 512;
static constexpr int kNWarps = 4;
static constexpr int B_H = 64;
static constexpr int B_TOPK = 64; // TopK block size
static constexpr int NUM_THREADS = kNWarps * 64;
static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits)
using Element = cutlass::bfloat16_t;
using elem_type = Element;
using ElementAccum = float;
using index_t = int64_t;
static constexpr int kBlockM = B_H;
static constexpr int kBlockN = B_TOPK;
static constexpr int kHeadDim = D_QK;
static constexpr int kHeadDimV = D_V;
static __device__ __forceinline__ void
devfunc(const SparseAttnFwdParams &params);
static void run(const SparseAttnFwdParams &params);
};
};
This diff is collapsed.
......@@ -602,6 +602,95 @@ struct Softmax {
}
return lse;
};
template<bool Is_first, bool Check_inf=false, typename Tensor0>
__forceinline__ __device__ void softmax_rescale_o_prefill_4x1(Tensor0& scores, v4f* acc_o, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp<float> max_op;
// Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if constexpr(Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !true
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#endif
// if (blockIdx.x == 0 && threadIdx.x == 0)
// {
// printf("threadIdx.x %.2f, scores_scale = %.4f\n",row_sum(mi), scores_scale );
// }
row_sum(mi) *= scores_scale;
for (int i = 0; i < 32; i++)
{
acc_o[i].x *= scores_scale;
acc_o[i].y *= scores_scale;
acc_o[i].z *= scores_scale;
acc_o[i].w *= scores_scale;
}
}
// if (blockIdx.x == 2)
// {
// printf("threadIdx.x %.2f \n",row_sum(mi) );
// }
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
// if (thread0())
// {
// printf("max sum %.3f %.3f \n", row_max(0), row_sum(0));
// }
};
template<bool Is_dropout=false, bool Split=false>
__forceinline__ __device__ TensorT normalize_softmax_lse_prefill_4x1(v4f *acc_o, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
// flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
TensorT lse = make_fragment_like(row_sum);
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
// if (thread0())
// {
// printf(" %.3f %.3f \n", row_max(0), row_sum(0));
// }
#pragma unroll
for (int mi = 0; mi < 1; ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
for (int i = 0; i < 32; i++)
{
acc_o[i].x *= scale;
acc_o[i].y *= scale;
acc_o[i].z *= scale;
acc_o[i].w *= scale;
}
}
return lse;
};
};
......
......@@ -1523,6 +1523,91 @@ __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor
}
#endif
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
template<typename Element, int k_idx>
__forceinline__ __device__ void qk_gemm(const __fp16x8_t& q_data, Element* k_lds_read_ptr, v4f* accs_f32)
{
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
union Bf16_storage {
__fp16x8_t data_128;
__fp16x4_t data_64[2];
uint16_t data_array[8];
};
constexpr int k_idx_even = k_idx % 4;
constexpr int n_offset = 16 * 32;
constexpr int k_offset = k_idx_even * 64 * 32;
Bf16_storage q_reg;
Bf16_storage k_reg;
q_reg.data_128 = q_data;
k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset);
// q_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(q_lds_read_ptr), k_offset, 2, 1, 0);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 0 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[0], true,false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[0], true,false);
#else
accs_f32[0] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[0]);
accs_f32[0] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[0]);
#endif
k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset + 1 * n_offset);
#if defined(__gfx938__)
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[1], true,false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[1], true,false);
#else
accs_f32[1] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[1]);
accs_f32[1] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[1]);
#endif
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 1 * n_offset + k_offset, 2, 1, 0);
k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset + 2 * n_offset);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 2 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32[2] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[2], true,false);
accs_f32[2] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[2], true,false);
#else
accs_f32[2] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[2]);
accs_f32[2] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[2]);
#endif
k_reg.data_128 = *reinterpret_cast<__fp16x8_t*>(k_lds_read_ptr + k_offset + 3 * n_offset);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 3 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32[3] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[0], k_reg.data_64[0], accs_f32[3], true,false);
accs_f32[3] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(q_reg.data_64[1], k_reg.data_64[1], accs_f32[3], true,false);
#else
accs_f32[3] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[0], k_reg.data_64[0], accs_f32[3]);
accs_f32[3] = __builtin_amdgcn_mmac_f32_16x16x16bf16(q_reg.data_64[1], k_reg.data_64[1], accs_f32[3]);
#endif
}
typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
template<int k_idx, int n_idx_val>
__forceinline__ __device__ void pv_gemm(const __fp16x4_t& p, int v_lds_read_ptr, v4f* acco_f32)
{
constexpr int k_idx_even = k_idx % 1;
constexpr int n_offset = 16 * 32 * 2;
typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
union Bf16_storage {
__fp16x8_t data_128;
__fp16x4_t data_64[2];
uint16_t data_array[8];
};
constexpr int k_offset = k_idx_even * 16 * 512 * 2;
// #if 1
Bf16_storage v_reg;
v_reg.data_128 = __builtin_amdgcn_ds_read_m32x16f16_alt((__attribute__((address_space(3))) __fp16*)(v_lds_read_ptr), k_offset + n_idx_val * n_offset);
#if defined(__gfx938__)
acco_f32[n_idx_val * 2] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(p, v_reg.data_64[0], acco_f32[n_idx_val * 2], true, false);
acco_f32[n_idx_val * 2 + 1] = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(p, v_reg.data_64[1], acco_f32[n_idx_val * 2 + 1], true, false);
#else
acco_f32[n_idx_val * 2] = __builtin_amdgcn_mmac_f32_16x16x16bf16(p, v_reg.data_64[0], acco_f32[n_idx_val * 2]);
acco_f32[n_idx_val * 2 + 1] = __builtin_amdgcn_mmac_f32_16x16x16bf16(p, v_reg.data_64[1], acco_f32[n_idx_val * 2 + 1]);
#endif
}
}
\ No newline at end of file
......@@ -77,7 +77,7 @@ if __name__ == '__main__':
(1840, 256),
(1592, 384),
(1521, 512),
(3000, 2048),
# Irregular shapes with OOB TopK
(95, 128),
(153, 256),
......@@ -146,6 +146,7 @@ if __name__ == '__main__':
performance_case_templates = [
# V3.2
(576, 128, 2048, [8192, 32768, 65536, 98304, 131072]),
(576, 64, 2048, [8192, 32768, 65536, 98304, 131072]),
# MODEL1 CONFIG1
(512, 64, 512, [8192, 32768, 49152, 65536]),
# MODEL1 CONFIG2
......@@ -154,9 +155,10 @@ if __name__ == '__main__':
]
performance_cases = [
TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=True)
TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=have_attn_sink)
for (d_qk, h_q, topk, s_kv_list) in performance_case_templates
for s_q in [4096]
for have_attn_sink in [False, True]
for s_kv in s_kv_list
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment