Commit 41b611f7 authored by Zeyu WANG, committed by GitHub

Add more GPU architectures support (#76)



* Add more GPU architectures support

* Merge fmha and mla runner

* add varlen & non-varlen support, and add non-contiguous tensor support

* update readme

* add varlen api

---------
Co-authored-by: dianzhangc <dianzhangc@nvidia.com>
parent 9edee0c0
#include <torch/python.h>

// Forward pass entry point (declared here, defined elsewhere in the extension).
void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v,
                            at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
                            at::Tensor o, at::Tensor lse,
                            int mask_mode_code, float softmax_scale,
                            int max_seqlen_q, int max_seqlen_kv, bool is_varlen);

// Backward pass entry point; writes the gradients into dq, dk and dv.
void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k,
                            at::Tensor v, at::Tensor o, at::Tensor lse,
                            at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
                            at::Tensor dq, at::Tensor dk, at::Tensor dv,
                            int mask_mode_code, float softmax_scale,
                            int max_seqlen_q, int max_seqlen_kv, bool is_varlen);

// Expose both entry points to Python as `fwd` and `bwd`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fwd", &FMHACutlassSM100FwdRun);
  m.def("bwd", &FMHACutlassSM100BwdRun);
}
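
Both entry points take `cumulative_seqlen_q` / `cumulative_seqlen_kv` together with `max_seqlen_q` / `max_seqlen_kv`. The commit shown here does not document how these are laid out, so the sketch below is only an illustration. It assumes the convention common to variable-length attention APIs: a prefix-sum array with a leading zero and `batch + 1` entries, so that sequence `i` occupies rows `[cu[i], cu[i + 1])` of the packed tensor. The helper names `build_cumulative_seqlens` and `max_seqlen` are hypothetical and are not part of this repository.

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper (not from the commit): turn per-sequence lengths into
// cumulative offsets, assuming a leading zero and batch + 1 entries.
std::vector<int32_t> build_cumulative_seqlens(const std::vector<int32_t>& seqlens) {
    std::vector<int32_t> cu(seqlens.size() + 1, 0);
    // cu[i + 1] = seqlens[0] + ... + seqlens[i]; cu[0] stays 0.
    std::partial_sum(seqlens.begin(), seqlens.end(), cu.begin() + 1);
    return cu;
}

// Hypothetical helper (not from the commit): longest sequence in the batch,
// the value one would presumably pass as max_seqlen_q or max_seqlen_kv.
int32_t max_seqlen(const std::vector<int32_t>& seqlens) {
    return seqlens.empty() ? 0 : *std::max_element(seqlens.begin(), seqlens.end());
}
```

Under that assumption, a batch with lengths 3, 5 and 2 would be packed into 10 rows with offsets 0, 3, 8, 10 and a maximum length of 5. Whether the kernels expect exactly this layout, and which integer dtype the offset tensors must have, should be checked against the runner source and the updated README.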