Commit 1fe94628 authored by sxtyzhangzk

[major] migrate to nunchaku

parent 453e1bfd
@@ -6,15 +6,27 @@
******************************************************************************/
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
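+// When building inside nunchaku we do not link libtorch at all:
+// pytorch_compat.h provides minimal stand-ins for the few at::/c10::
+// types and helpers this file uses.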
+#define BUILD_NUNCHAKU 1
+#if BUILD_NUNCHAKU
+#include "pytorch_compat.h"
+using namespace pytorch_compat;
+#else
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
+#endif
#include <cutlass/numeric_types.h>
#include "flash.h"
#include "static_switch.h"
#include "src/flash.h"
#include "src/static_switch.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
@@ -33,12 +45,12 @@ void set_params_fprop(Flash_fwd_params &params,
const size_t d,
const size_t d_rounded,
// device pointers
-const at::Tensor q,
-const at::Tensor k,
-const at::Tensor v,
+at::Tensor q,
+at::Tensor k,
+at::Tensor v,
at::Tensor out,
-void *cu_seqlens_q_d,
-void *cu_seqlens_k_d,
+const void *cu_seqlens_q_d,
+const void *cu_seqlens_k_d,
void *seqused_k,
void *p_d,
void *softmax_lse_d,
@@ -74,8 +86,8 @@ void set_params_fprop(Flash_fwd_params &params,
params.o_batch_stride = out.stride(0);
}
-params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
-params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+params.cu_seqlens_q = static_cast<const int *>(cu_seqlens_q_d);
+params.cu_seqlens_k = static_cast<const int *>(cu_seqlens_k_d);
params.seqused_k = static_cast<int *>(seqused_k);
// P = softmax(QK^T)
@@ -139,7 +151,7 @@ void set_params_dgrad(Flash_bwd_params &params,
const at::Tensor k,
const at::Tensor v,
const at::Tensor out,
-const at::Tensor dout,
+at::Tensor dout,
at::Tensor dq,
at::Tensor dk,
at::Tensor dv,
@@ -206,7 +218,8 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
if (params.num_splits <= 1 && !force_split_kernel) { // num_splits == 0 means it was not explicitly set
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
} else {
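+// The split-KV kernels are compiled out of the nunchaku build, so this
+// branch must be unreachable.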
-run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
+// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
+assert(false);
}
});
});
@@ -217,6 +230,8 @@ void run_mha_fwd_block(Flash_fwd_params &params, cudaStream_t stream, bool force
FWD_BLOCK_HEADDIM_SWITCH(params.d, [&] {
if (params.num_splits <= 1 && !force_split_kernel) {
run_mha_fwd_block_<elem_type, kHeadDim>(params, stream);
+} else {
+assert(false);
}
});
});
@@ -423,11 +438,15 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
if (p_dropout > 0.0) {
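+// Without ATen there is no CUDA generator to draw philox state from, so the
+// nunchaku build rejects dropout outright.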
+#if BUILD_NUNCHAKU
+assert(false);
+#else
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
+#endif
}
if (alibi_slopes_.has_value()) {
@@ -471,8 +490,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
-const at::Tensor &cu_seqlens_q, // b+1
-const at::Tensor &cu_seqlens_k, // b+1
+at::Tensor &cu_seqlens_q, // b+1
+at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
@@ -621,11 +640,15 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
if (p_dropout > 0.0) {
+#if BUILD_NUNCHAKU
+assert(false);
+#else
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
+#endif
}
if (alibi_slopes_.has_value()) {
@@ -658,6 +681,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}
+#if !BUILD_NUNCHAKU
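+// Everything down to the matching #endif (the backward pass and the kv-cache
+// forward) depends on ATen and is compiled out of the nunchaku build.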
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d <= 32) {
@@ -1440,6 +1465,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
return {out, softmax_lse};
}
+#endif
// added by JXGuo
const int SPARSE_SIZE = 128;
@@ -1455,7 +1482,7 @@ mha_fwd_block(const at::Tensor &q,
// const bool is_blocksparse,
const int m_block_dim,
const int n_block_dim,
-const at::Tensor &head_mask_type, // (num_heads)
+at::Tensor &head_mask_type, // (num_heads)
c10::optional<at::Tensor> &streaming_info_, // (num_heads, 2)
c10::optional<at::Tensor> &row_blockmask_, // (batch_size, num_blocksparse_heads, seqlen_m / m_block_dim, seqlen_n / n_block_dim)
const int max_seqlen_q_,
@@ -1651,11 +1678,15 @@ mha_fwd_block(const at::Tensor &q,
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
if (p_dropout > 0.0) {
+#if BUILD_NUNCHAKU
+assert(false);
+#else
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
+#endif
}
auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -1670,6 +1701,8 @@ mha_fwd_block(const at::Tensor &q,
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}
+#if !BUILD_NUNCHAKU
std::vector<at::Tensor>
mha_bwd_block(const at::Tensor &dout, // total_q x num_heads x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
@@ -1992,6 +2025,10 @@ mha_bwd_block(const at::Tensor &dout, // total_q x num_heads x head_size
return { dq, dk, dv, softmax_d };
}
+#endif
+#if !BUILD_NUNCHAKU
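+// nunchaku calls the C++ entry points directly through the torch-free
+// declarations in flash_api.h, so no Python bindings are emitted.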
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
@@ -2003,3 +2040,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fwd_block", &mha_fwd_block, "Forward pass, with blockmask");
m.def("bwd_block", &mha_bwd_block, "Forward pass, with blockmask");
}
+#endif
#pragma once
#include "Tensor.h"
std::vector<Tensor>
mha_fwd(Tensor &q, // batch_size x seqlen_q x num_heads x head_size
Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
// c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
// c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax
// c10::optional<at::Generator> gen_
);
std::vector<Tensor>
mha_varlen_fwd(Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
// std::optional<Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
Tensor &cu_seqlens_q, // b+1
Tensor &cu_seqlens_k, // b+1
// std::optional<Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
// std::optional<Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
// std::optional<Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax);
std::vector<Tensor>
mha_fwd_block(const Tensor &q, // total_q x num_heads x head_size, total := \sum_{i=0}^{b} s_i
const Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const Tensor &cu_seqlens_q, // b+1
const Tensor &cu_seqlens_k, // b+1
const int m_block_dim,
const int n_block_dim,
Tensor &head_mask_type, // (num_heads)
std::optional<Tensor> streaming_info_, // (num_heads, 2)
std::optional<Tensor> row_blockmask_, // (batch_size, num_blocksparse_heads, max_seqlen_m / m_block_dim, k)
const int max_seqlen_q_,
const int max_seqlen_k_,
const float p_dropout,
const float softmax_scale,
const bool is_causal,
const bool exact_streaming,
const bool return_softmax,
int window_size_left,
int window_size_right);
\ No newline at end of file
#include "flash_api.h"
#include "pytorch_compat.h"
using namespace pytorch_compat;
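// Bridging overloads: each wrapper below redeclares the original
// c10::optional-based entry point and forwards to it with empty optionals
// for the arguments the torch-free API does not expose.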
std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_);
std::vector<Tensor>
mha_fwd(Tensor &q, // batch_size x seqlen_q x num_heads x head_size
Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
// c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
// c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax
// c10::optional<at::Generator> gen_
)
{
std::optional<Tensor> out_ = {};
std::optional<Tensor> alibi_slopes_ = {};
return mha_fwd(
q, k, v,
out_, alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
return_softmax,
{}
);
}
std::vector<at::Tensor>
mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
at::Tensor &cu_seqlens_q, // b+1
at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_);
std::vector<Tensor>
mha_varlen_fwd(Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
// std::optional<Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
Tensor &cu_seqlens_q, // b+1
Tensor &cu_seqlens_k, // b+1
// std::optional<Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
// std::optional<Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
// std::optional<Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax)
{
std::optional<Tensor> out_ = {};
std::optional<Tensor> seqused_k = {};
std::optional<Tensor> alibi_slopes_ = {};
return mha_varlen_fwd(
q, k, v,
out_,
cu_seqlens_q, cu_seqlens_k,
seqused_k, alibi_slopes_,
max_seqlen_q, max_seqlen_k,
p_dropout, softmax_scale, zero_tensors, is_causal,
window_size_left, window_size_right,
return_softmax,
{}
);
}
std::vector<at::Tensor>
mha_fwd_block(const at::Tensor &q,
// total_q x num_heads x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &k,
// total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v,
// total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
const int m_block_dim,
const int n_block_dim,
at::Tensor &head_mask_type, // (num_heads)
c10::optional<at::Tensor> &streaming_info_, // (num_heads, 2)
c10::optional<at::Tensor> &row_blockmask_, // (batch_size, num_blocksparse_heads, max_seqlen_m / m_block_dim, k)
const int max_seqlen_q_,
const int max_seqlen_k_,
const float p_dropout,
const float softmax_scale,
const bool is_causal,
const bool exact_streaming,
const bool return_softmax,
int window_size_left,
int window_size_right,
c10::optional<at::Generator> gen_);
std::vector<Tensor>
mha_fwd_block(const Tensor &q, // total_q x num_heads x head_size, total := \sum_{i=0}^{b} s_i
const Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const Tensor &cu_seqlens_q, // b+1
const Tensor &cu_seqlens_k, // b+1
const int m_block_dim,
const int n_block_dim,
Tensor &head_mask_type, // (num_heads)
std::optional<Tensor> streaming_info_, // (num_heads, 2)
std::optional<Tensor> row_blockmask_, // (batch_size, num_blocksparse_heads, max_seqlen_m / m_block_dim, k)
const int max_seqlen_q_,
const int max_seqlen_k_,
const float p_dropout,
const float softmax_scale,
const bool is_causal,
const bool exact_streaming,
const bool return_softmax,
int window_size_left,
int window_size_right)
{
return mha_fwd_block(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type, streaming_info_, row_blockmask_,
max_seqlen_q_, max_seqlen_k_,
p_dropout, softmax_scale, is_causal, exact_streaming, return_softmax,
window_size_left, window_size_right,
{}
);
}
\ No newline at end of file
@@ -7,9 +7,20 @@
#pragma once
+#define BUILD_NUNCHAKU 1
#include <cuda.h>
#include <vector>
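+// Dropout is disabled under BUILD_NUNCHAKU (the entry points assert on
+// p_dropout > 0), so an empty PhiloxCudaState stand-in is enough to keep
+// Flash_fwd_params compiling without ATen.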
+#if BUILD_NUNCHAKU
+namespace pytorch_compat::at {
+struct PhiloxCudaState {};
+}
+using namespace pytorch_compat;
+#else
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
@@ -18,6 +29,8 @@
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
#endif
+#endif
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
@@ -77,8 +90,8 @@ struct Flash_fwd_params : public Qkv_params {
float scale_softmax_log2;
// array of length b+1 holding starting offset of each sequence.
-int * __restrict__ cu_seqlens_q;
-int * __restrict__ cu_seqlens_k;
+const int * __restrict__ cu_seqlens_q;
+const int * __restrict__ cu_seqlens_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
@@ -22,6 +22,15 @@
#include "flash_blockmask.h"
+#if BUILD_NUNCHAKU
+namespace pytorch_compat::at::cuda::philox {
+constexpr std::tuple<int64_t, int64_t> unpack(at::PhiloxCudaState state) {
+    return {0, 0};
+}
+}
+using namespace pytorch_compat;
+#endif
namespace flash {
using namespace cute;
@@ -7,8 +7,27 @@
#pragma once
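+// Minimal replacements for the c10 CUDA error-checking macros and the
+// at::cuda device-property lookup used by the launch code below.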
+#if BUILD_NUNCHAKU
+#include "common.h"
+inline void C10_CUDA_KERNEL_LAUNCH_CHECK() {
+    checkCUDA(cudaPeekAtLastError());
+}
+namespace pytorch_compat::at::cuda {
+using ::getCurrentDeviceProperties;
+}
+#define C10_CUDA_CHECK checkCUDA
+using namespace pytorch_compat;
+#else
#include <ATen/cuda/CUDAContext.h>
+#endif
#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
@@ -36,6 +36,32 @@
} \
}()
+#if BUILD_NUNCHAKU
+#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
+[&] { \
+if (HEADDIM <= 64) { \
+constexpr static int kHeadDim = 64; \
+return __VA_ARGS__(); \
+} else if (HEADDIM <= 128) { \
+constexpr static int kHeadDim = 128; \
+return __VA_ARGS__(); \
+} \
+}()
+#define FWD_BLOCK_HEADDIM_SWITCH(HEADDIM, ...) \
+[&] { \
+if (HEADDIM <= 64) { \
+constexpr static int kHeadDim = 64; \
+return __VA_ARGS__(); \
+} else if (HEADDIM <= 128) { \
+constexpr static int kHeadDim = 128; \
+return __VA_ARGS__(); \
+} \
+}()
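+// Only head dims 64 and 128 are instantiated for nunchaku. Typical use:
+// FWD_HEADDIM_SWITCH(params.d, [&] { run_mha_fwd_<elem_type, kHeadDim>(params, stream); });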
+#else
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
@@ -79,6 +105,8 @@
} \
}()
+#endif
#define BWD_BLOCK_HEADDIM_SWITCH(HEADDIM, ...)\
[&] { \
if (HEADDIM <= 32) { \