Commit b94fdd0f authored by zhanghj2's avatar zhanghj2
Browse files

prefill支持head 16

parent 38421051
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
enum class FwdFeatures : int { enum class FwdFeatures : int {
HEAD_16,
HEAD_64, HEAD_64,
HEAD_128, HEAD_128,
...@@ -26,6 +27,7 @@ class FwdImplBase : public ImplBase< ...@@ -26,6 +27,7 @@ class FwdImplBase : public ImplBase<
class Fwd_Sm90_Impl : public FwdImplBase { class Fwd_Sm90_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES( DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_16,
FwdFeatures::HEAD_64, FwdFeatures::HEAD_64,
FwdFeatures::HEAD_128, FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512, FwdFeatures::HEAD_DIM_512,
...@@ -45,57 +47,6 @@ protected: ...@@ -45,57 +47,6 @@ protected:
} }
}; };
// Forward-pass implementation derived from FwdImplBase that declares support
// for the HEAD_64 feature together with HEAD_DIM_512, HEAD_DIM_576, ATTN_SINK,
// SINK_LSE and TOPK_LENGTH.
// NOTE(review): the name suggests an SM100-specific, 64-head kernel path —
// confirm against the dispatcher that selects implementations.
class Fwd_Sm100_Head64_Impl : public FwdImplBase {
// Feature set this implementation advertises as supported.
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_64,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::HEAD_DIM_576,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
// Overrides the base-class run_ hook. Dispatches on params.d_qk through
// DISPATCH_HEAD_DIM, passing a callback and the name HEAD_DIM_QK.
// `required_features` is not read in this body.
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
// NOTE(review): the kernel call below is commented out, so the callback
// body is empty and this implementation launches nothing as written.
// sm100::fwd::head64::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
});
}
};
// Forward-pass implementation derived from FwdImplBase that declares support
// for the HEAD_128 feature together with HEAD_DIM_512, HEAD_DIM_576,
// ATTN_SINK, SINK_LSE and TOPK_LENGTH.
// NOTE(review): the name suggests an SM100-specific, 128-head kernel path —
// confirm against the dispatcher that selects implementations.
class Fwd_Sm100_Head128_Impl : public FwdImplBase {
// Feature set this implementation advertises as supported.
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::HEAD_DIM_576,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
// Overrides the base-class run_ hook. Dispatches on params.d_qk through
// DISPATCH_HEAD_DIM, passing a callback and the name HEAD_DIM_QK.
// `required_features` is not read in this body.
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
// NOTE(review): the kernel call below is commented out, so the callback
// body is empty and this implementation launches nothing as written.
// sm100::fwd::head128::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
});
}
};
// Forward-pass implementation derived from FwdImplBase that declares support
// for HEAD_128, HEAD_DIM_512, ATTN_SINK, SINK_LSE and TOPK_LENGTH. Unlike the
// other HEAD_128 implementation it does not list HEAD_DIM_576.
// NOTE(review): the name suggests a variant specialised for small top-k
// values — confirm the selection criteria with the dispatcher.
class Fwd_Sm100_Head128_Small_TopK_Impl : public FwdImplBase {
// Feature set this implementation advertises as supported.
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
// Overrides the base-class run_ hook. Neither `params` nor
// `required_features` is read in this body.
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
// NOTE(review): the only statement is the commented-out kernel call below
// (Prefill mode, template argument 512), so this function is a no-op as written.
// sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::Prefill, 512>(params);
}
};
static std::vector<at::Tensor> sparse_attn_prefill_interface( static std::vector<at::Tensor> sparse_attn_prefill_interface(
const at::Tensor &q, const at::Tensor &q,
const at::Tensor &kv, const at::Tensor &kv,
...@@ -187,7 +138,9 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface( ...@@ -187,7 +138,9 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
}; };
std::vector<FwdFeatures> required_features; std::vector<FwdFeatures> required_features;
if (h_q == 64) { if (h_q == 16) {
required_features.push_back(FwdFeatures::HEAD_16);
} else if (h_q == 64) {
required_features.push_back(FwdFeatures::HEAD_64); required_features.push_back(FwdFeatures::HEAD_64);
} else if (h_q == 128) { } else if (h_q == 128) {
required_features.push_back(FwdFeatures::HEAD_128); required_features.push_back(FwdFeatures::HEAD_128);
......
...@@ -64,7 +64,7 @@ if __name__ == '__main__': ...@@ -64,7 +64,7 @@ if __name__ == '__main__':
TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, d_qk=d_qk) TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, d_qk=d_qk)
for d_qk in [512, 576] for d_qk in [512, 576]
for h_q in [ for h_q in [
128, 64 16, 128, 64
] ]
for s_kv, topk in [ for s_kv, topk in [
# Regular shapes # Regular shapes
...@@ -92,7 +92,7 @@ if __name__ == '__main__': ...@@ -92,7 +92,7 @@ if __name__ == '__main__':
TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, have_attn_sink=have_attn_sink, have_topk_length=have_topk_length, d_qk=d_qk) TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, have_attn_sink=have_attn_sink, have_topk_length=have_topk_length, d_qk=d_qk)
for d_qk in [512, 576] for d_qk in [512, 576]
for h_q in [ for h_q in [
128, 64 16, 128, 64
] ]
for s_kv, topk in [ for s_kv, topk in [
(592, 128), (592, 128),
...@@ -114,7 +114,7 @@ if __name__ == '__main__': ...@@ -114,7 +114,7 @@ if __name__ == '__main__':
TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=True, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk) TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=True, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk)
for d_qk in [512, 576] for d_qk in [512, 576]
for h_q in [ for h_q in [
128, 64 16, 128, 64
] ]
for s_q, s_kv, topk in [ for s_q, s_kv, topk in [
(1, 128, 128), (1, 128, 128),
...@@ -150,6 +150,7 @@ if __name__ == '__main__': ...@@ -150,6 +150,7 @@ if __name__ == '__main__':
(512, 64, 512, [8192, 32768, 49152, 65536]), (512, 64, 512, [8192, 32768, 49152, 65536]),
# MODEL1 CONFIG2 # MODEL1 CONFIG2
(512, 128, 1024, [8192, 32768, 49152, 65536]), (512, 128, 1024, [8192, 32768, 49152, 65536]),
(512, 16, 1024, [8192, 32768, 49152, 65536]),
] ]
performance_cases = [ performance_cases = [
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment