Unverified commit a03f6f8e authored by Kirthi Shankar Sivamani, committed by GitHub

Enable CUDA graphs (#386)



* Add RNG state to kernel launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Save seed and offset for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Single thread write to global mem
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1colblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1rowblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change forward c++ APIs to save RNG state for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change backward c++ APIs to set RNG state for bprop launcher
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Python side API changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix; only save seeds instead of full offset
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Account for 3D grid size
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4c98d0b4
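
Context for the change (not part of the commit itself): keeping the Philox seed and offset in GPU memory means the backward pass no longer needs torch.cuda.get_rng_state()/set_rng_state(), which are host-side calls that cannot be replayed inside a captured CUDA graph. A minimal sketch of what this unlocks, assuming flash_attn is installed and that flash_attn_func exposes the (q, k, v, dropout_p, ..., causal) interface of this version of the library; shapes, dropout_p and the warm-up loop are illustrative only:

# Sketch only: capture FlashAttention forward + backward in a CUDA graph.
import torch
from flash_attn import flash_attn_func

q, k, v = [torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3)]

# Warm up on a side stream before capture, as the CUDA graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        flash_attn_func(q, k, v, dropout_p=0.1, causal=True).sum().backward()
torch.cuda.current_stream().wait_stream(s)

q.grad = k.grad = v.grad = None
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    out = flash_attn_func(q, k, v, dropout_p=0.1, causal=True)
    out.sum().backward()

graph.replay()  # dropout seed/offset are read from device memory, so replay works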
@@ -294,11 +294,16 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
                      softmax_scale,
                      is_causal);
-    if (p_dropout > 0.0) {
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // Forward kernel will populate memory with the seed and offset.
+    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+    if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
@@ -315,7 +320,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
         if (out_.has_value()) { out_.value().copy_(out); }
     }
-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
+    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
 std::vector<at::Tensor>
@@ -448,11 +453,16 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      softmax_scale,
                      is_causal);
-    if (p_dropout > 0.0) {
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // Forward kernel will populate memory with the seed and offset.
+    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+    if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
@@ -469,7 +479,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         if (out_.has_value()) { out_.value().copy_(out); }
     }
-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
+    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
 void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
@@ -507,7 +517,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         const float p_dropout,         // probability to drop
         const float softmax_scale,
         const bool is_causal,
-        c10::optional<at::Generator> gen_) {
+        c10::optional<at::Generator> gen_,
+        c10::optional<at::Tensor> &rng_state) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -669,10 +680,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;
-    if (is_dropout) {
+    if ( rng_state.has_value() ) {
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+    } else if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
     }
     launch(params, stream, /*configure=*/false);
@@ -709,7 +725,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         const float softmax_scale,
         const bool zero_tensors,
         const bool is_causal,
-        c10::optional<at::Generator> gen_
+        c10::optional<at::Generator> gen_,
+        c10::optional<at::Tensor> &rng_state
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -881,10 +898,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;
-    if (is_dropout) {
+    if ( rng_state.has_value() ) {
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+    } else if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
     }
     launch(params, stream, /*configure=*/false);
......
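
In short, the C++ API changes above make every forward binding return one extra tensor and every backward binding accept it: rng_state is a two-element int64 CUDA tensor that the forward kernel fills with the Philox seed (index 0) and offset (index 1). A hedged sketch of inspecting such a tensor; describe_rng_state is a hypothetical helper, not part of the library:

import torch

def describe_rng_state(rng_state: torch.Tensor):
    # Expect the layout produced by the forward bindings above: [seed, offset].
    assert rng_state.is_cuda and rng_state.dtype == torch.int64 and rng_state.numel() == 2
    seed, offset = rng_state.tolist()  # copies device -> host; for debugging only
    return {"philox_seed": seed, "philox_offset": offset}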
@@ -91,6 +91,9 @@ struct Flash_fwd_params : public Qkv_params {
     // Random state.
     at::PhiloxCudaState philox_args;
+    // Pointer to the RNG seed (idx 0) and offset (idx 1).
+    uint64_t * rng_state;
     bool is_bf16;
     bool is_causal;
 };
......
@@ -755,9 +755,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view);
     }
-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    unsigned long long seed = std::get<0>(seeds);
-    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
+    auto seed = params.rng_state[0];
+    auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;
     clear(acc_dv);
     clear(acc_dk);
@@ -1301,9 +1300,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     #pragma unroll
     for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<0>(taccScS_row(mi))); }
-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    unsigned long long seed = std::get<0>(seeds);
-    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
+    auto seed = params.rng_state[0];
+    auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;
     clear(acc_dq);
......
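
The backward kernels above now read the seed and the base offset from params.rng_state instead of unpacking philox_args. The offset arithmetic is unchanged: each (batch, head) pair owns a block of 32 Philox subsequences, one per lane of a warp, which is also why the host advances the generator by batch_size * nheads * 32 per call. A small worked example of that bookkeeping (toy sizes, illustration only):

# Worked example (sketch) of the offset arithmetic used above.
b, h = 3, 4                 # toy batch size and number of heads
base_offset = 1000          # stands in for rng_state[1] saved by the forward kernel
offsets = {base_offset + (bidb * h + bidh) * 32 + lane
           for bidb in range(b) for bidh in range(h) for lane in range(32)}
assert len(offsets) == b * h * 32                     # one distinct subsequence per lane
assert max(offsets) == base_offset + b * h * 32 - 1   # next call starts just past this block
counter_offset = b * h * 32                           # matches params.b * params.h * 32 in the diff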
@@ -130,6 +130,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // The thread index.
     const int tidx = threadIdx.x;
+    // The global block index.
+    const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kBlockN = Kernel_traits::kBlockN;
@@ -308,6 +310,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     unsigned long long seed = std::get<0>(seeds);
     unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
+    // Save seed and offset for backward.
+    if (block_id == 0 && tidx == 0) {
+        params.rng_state[0] = seed;
+        params.rng_state[1] = std::get<1>(seeds);
+    }
     clear(acc_o);
     // For performance reason, we separate out two kinds of iterations:
......
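
The forward-kernel hunks above add a flattened 3D block index so that exactly one thread in the whole launch (block_id == 0 and tidx == 0) writes the seed and offset to params.rng_state; this corresponds to the "Single thread write to global mem" and "Account for 3D grid size" items in the commit message. A quick sketch of the linearisation (toy grid, illustration only):

# Sketch of block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z
def global_block_id(bx: int, by: int, bz: int, gx: int, gy: int) -> int:
    return bx + by * gx + gx * gy * bz

grid = (4, 3, 2)  # toy gridDim (x, y, z)
ids = [global_block_id(bx, by, bz, grid[0], grid[1])
       for bz in range(grid[2]) for by in range(grid[1]) for bx in range(grid[0])]
assert sorted(ids) == list(range(grid[0] * grid[1] * grid[2]))  # unique and dense
assert ids[0] == 0  # the (0, 0, 0) block is the single writer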
@@ -39,45 +39,46 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
         q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
     )
-    return out, q, k, v, out_padded, softmax_lse, S_dmask
+    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
 def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal, return_softmax):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
         q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
         softmax_scale, False, causal, return_softmax, None
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
-    return out, q, k, v, out_padded, softmax_lse, S_dmask
+    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
-                         dropout_p, softmax_scale, causal):
+                         dropout_p, softmax_scale, causal, rng_state=None):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
-        dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None
+        dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p,
+        softmax_scale, causal, None, rng_state
     )
     return dq, dk, dv, softmax_d
 def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                                 cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                                dropout_p, softmax_scale, causal):
+                                dropout_p, softmax_scale, causal, rng_state=None):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None
+        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None, rng_state
     )
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
@@ -88,11 +89,9 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
             qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale,
             causal=causal, return_softmax=return_softmax and dropout_p > 0
         )
@@ -105,18 +104,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
         dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
-            ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
         )
         dqkv = dqkv[..., :dout.shape[-1]]  # We could have padded the head dimension
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dqkv, None, None, None, None
@@ -124,11 +118,9 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
             qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
             dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
         )
@@ -142,19 +134,14 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
         dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
         _flash_attn_varlen_backward(
             dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
             cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
-            ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
         )
         dqkv = dqkv[..., :dout.shape[-1]]  # We could have padded the head dimension
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dqkv, None, None, None, None, None, None
@@ -162,11 +149,9 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
             q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal,
             return_softmax=return_softmax and dropout_p > 0
         )
@@ -179,20 +164,16 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq = torch.empty_like(q)
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse,
-            dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            rng_state=rng_state
         )
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., :dout.shape[-1]]
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dkv, None, None, None, None
@@ -201,11 +182,9 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                 softmax_scale, causal, return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
             q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
             dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
         )
@@ -221,21 +200,16 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq = torch.empty_like(q)
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
         _flash_attn_varlen_backward(
             dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
             cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
-            ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
         )
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., :dout.shape[-1]]
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dkv, None, None, None, None, None, None, None, None
@@ -243,11 +217,9 @@ class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
             q, k, v, dropout_p, softmax_scale, causal=causal,
             return_softmax=return_softmax and dropout_p > 0
         )
@@ -260,19 +232,15 @@ class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse,
-            dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            rng_state=rng_state
         )
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., :dout.shape[-1]]
         dv = dv[..., :dout.shape[-1]]
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dk, dv, None, None, None, None, None, None, None, None
@@ -281,11 +249,9 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                 softmax_scale, causal, return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
             q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
             dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
         )
@@ -301,19 +267,15 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_varlen_backward(
             dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            rng_state=rng_state
        )
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., :dout.shape[-1]]
         dv = dv[..., :dout.shape[-1]]
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dk, dv, None, None, None, None, None, None, None, None
......
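
One practical consequence of the Python changes above: backward no longer saves and restores the global CUDA RNG state; it rebuilds the dropout mask from the rng_state tensor saved by forward. A hedged sanity-check sketch (flash_attn_func and the tolerances are assumptions, and FlashAttention's backward is not necessarily bitwise deterministic, hence allclose rather than equal):

# Sketch: repeated backward passes reuse the same dropout mask via rng_state,
# regardless of what happens to the global CUDA RNG in between.
import torch
from flash_attn import flash_attn_func

q, k, v = [torch.randn(1, 256, 8, 64, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3)]
loss = flash_attn_func(q, k, v, dropout_p=0.1).sum()

loss.backward(retain_graph=True)
ref = q.grad.clone()
q.grad = k.grad = v.grad = None

torch.randn(4096, device="cuda")  # perturb the global CUDA RNG in between
loss.backward()
assert torch.allclose(q.grad, ref, atol=1e-2, rtol=1e-2)  # loose, illustrative tolerance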