Commit b94fdd0f authored by zhanghj2's avatar zhanghj2
Browse files

prefill支持head 16

parent 38421051
......@@ -8,6 +8,7 @@
enum class FwdFeatures : int {
HEAD_16,
HEAD_64,
HEAD_128,
......@@ -26,6 +27,7 @@ class FwdImplBase : public ImplBase<
class Fwd_Sm90_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_16,
FwdFeatures::HEAD_64,
FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512,
......@@ -45,57 +47,6 @@ protected:
}
};
class Fwd_Sm100_Head64_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_64,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::HEAD_DIM_576,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
// sm100::fwd::head64::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
});
}
};
class Fwd_Sm100_Head128_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::HEAD_DIM_576,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
// sm100::fwd::head128::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
});
}
};
class Fwd_Sm100_Head128_Small_TopK_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
// sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::Prefill, 512>(params);
}
};
static std::vector<at::Tensor> sparse_attn_prefill_interface(
const at::Tensor &q,
const at::Tensor &kv,
......@@ -187,7 +138,9 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
};
std::vector<FwdFeatures> required_features;
if (h_q == 64) {
if (h_q == 16) {
required_features.push_back(FwdFeatures::HEAD_16);
} else if (h_q == 64) {
required_features.push_back(FwdFeatures::HEAD_64);
} else if (h_q == 128) {
required_features.push_back(FwdFeatures::HEAD_128);
......
......@@ -64,7 +64,7 @@ if __name__ == '__main__':
TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, d_qk=d_qk)
for d_qk in [512, 576]
for h_q in [
128, 64
16, 128, 64
]
for s_kv, topk in [
# Regular shapes
......@@ -92,7 +92,7 @@ if __name__ == '__main__':
TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, have_attn_sink=have_attn_sink, have_topk_length=have_topk_length, d_qk=d_qk)
for d_qk in [512, 576]
for h_q in [
128, 64
16, 128, 64
]
for s_kv, topk in [
(592, 128),
......@@ -114,7 +114,7 @@ if __name__ == '__main__':
TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=True, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)
for d_qk in [512, 576]
for h_q in [
128, 64
16, 128, 64
]
for s_q, s_kv, topk in [
(1, 128, 128),
......@@ -150,6 +150,7 @@ if __name__ == '__main__':
(512, 64, 512, [8192, 32768, 49152, 65536]),
# MODEL1 CONFIG2
(512, 128, 1024, [8192, 32768, 49152, 65536]),
(512, 16, 1024, [8192, 32768, 49152, 65536]),
]
performance_cases = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment