Commit a43fbbf1 authored by Woosuk Kwon

Merge remote-tracking branch 'tri/main'

parents 498cd8c3 85881f54
@@ -44,7 +44,7 @@ jobs:
        # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
        os: [ubuntu-20.04]
        python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
-       torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240105']
+       torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240207']
        cuda-version: ['11.8.0', '12.2.2']
        # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
        # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
@@ -63,7 +63,7 @@ jobs:
          python-version: '3.7'
        - torch-version: '2.2.0'
          python-version: '3.7'
-       - torch-version: '2.3.0.dev20240105'
+       - torch-version: '2.3.0.dev20240207'
          python-version: '3.7'
        # Pytorch <= 2.0 only supports CUDA <= 11.8
        - torch-version: '1.12.1'
......
@@ -205,7 +205,8 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
      params.num_splits = num_splits;
      if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
          if (num_splits < 1) {
-             params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
+             // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
+             params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
          }
          if (params.num_splits > 1) {
              at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
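Context for the hunk above: the new comment explains the `* 2` factor — with 128 threads per block, roughly two blocks can be resident per SM, so the heuristic is told to target twice the SM count. A minimal Python sketch of that idea, using a hypothetical `pick_num_splits` (an illustration, not the actual `num_splits_heuristic`):

import torch  # only used to mirror the PyTorch-based project; not required by the sketch

# Illustrative only: a toy stand-in for num_splits_heuristic. The real function also
# caps the split count and checks that extra splits actually improve efficiency.
def pick_num_splits(num_blocks, num_sms, max_splits=128):
    # Target 2 resident blocks per SM, mirroring the "128 threads per block" comment
    # in the diff (assumption for illustration).
    target = 2 * num_sms
    for splits in range(1, max_splits + 1):
        if num_blocks * splits >= target:
            return splits
    return max_splits

# Example: decode with batch 1, 32 heads, one M block on a 108-SM GPU gives only 32
# blocks, so the sketch would split the KV dimension 7 ways to fill the machine.
print(pick_num_splits(num_blocks=1 * 32 * 1, num_sms=108))  # -> 7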
@@ -295,8 +296,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
      // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
      // H/t Daniel Haziza
      const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+     const int ngroups = num_heads / num_heads_k;
      if (seqlenq_ngroups_swapped) {
-         const int ngroups = num_heads / num_heads_k;
          q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
          seqlen_q = ngroups;
          num_heads = num_heads_k;

@@ -323,7 +324,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
      TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
      CHECK_DEVICE(out);
      TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-     CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
+     CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
+     if (seqlenq_ngroups_swapped) {
+         out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+     }
      if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  } else {
      out = torch::empty_like(q_padded);
@@ -494,8 +498,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
      // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
      // H/t Daniel Haziza
      const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+     const int ngroups = num_heads / num_heads_k;
      if (seqlenq_ngroups_swapped) {
-         const int ngroups = num_heads / num_heads_k;
          q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
          max_seqlen_q = ngroups;
          num_heads = num_heads_k;

@@ -550,6 +554,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
      CHECK_DEVICE(out);
      TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
      CHECK_SHAPE(out, total_q, num_heads, head_size_og);
+     CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
+     if (seqlenq_ngroups_swapped) {
+         out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
+     }
      if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  } else {
      out = torch::empty_like(q_padded);
......
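Context for the flash_api.cpp hunks above: during decode (seqlen_q == 1) with grouped-query attention, it is faster to treat the query-head groups as the sequence dimension, and the new code now applies the same reshape to a caller-supplied `out` tensor so it matches the swapped `q`. A rough PyTorch sketch of the shape manipulation (illustrative shapes, not the extension's API):

import torch

# Decode-time GQA: 32 query heads share 8 KV heads, seqlen_q == 1.
b, h_q, h_k, d = 2, 32, 8, 128
ngroups = h_q // h_k  # 4

q = torch.randn(b, 1, h_q, d)
# (b, 1, (h_k * ngroups), d) -> (b, ngroups, h_k, d): the group index becomes the
# "sequence" dimension, so the kernel processes ngroups query rows per KV head
# instead of a single row per query head.
q_swapped = q.reshape(b, h_k, ngroups, d).transpose(1, 2)
assert q_swapped.shape == (b, ngroups, h_k, d)

# A caller-provided out tensor keeps the original (b, 1, h_q, d) layout, so it is
# viewed the same way before the kernel writes into it (this mirrors the added
# reshape of `out` under seqlenq_ngroups_swapped).
out = torch.empty(b, 1, h_q, d)
out_swapped = out.reshape(b, h_k, ngroups, d).transpose(1, 2)
assert out_swapped.shape == q_swapped.shape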
@@ -68,14 +68,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
      // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
      // Otherwise we might read OOB elements from gK and gV.
      if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
-         const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
-             + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-         const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-         Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
-                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                 make_stride(params.o_row_stride, _1{}));
-         Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
-                                   Shape<Int<kBlockM>>{}, Stride<_1>{});
+         Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+                                               + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
+                                 make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                                 make_stride(params.o_row_stride, params.o_head_stride, _1{}));
+         Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+         Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
+                                   make_shape(params.b, params.h, params.seqlen_q),
+                                   make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
+         Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));

          typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
          auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -108,25 +110,27 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
      // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
      // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
-     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
-         + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
-     // We move K and V to the last block.
-     const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
-         + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
-     const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
-         + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
      const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
          + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
-     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
-                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                             make_stride(params.q_row_stride, _1{}));
-     Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
-                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                             make_stride(params.k_row_stride, _1{}));
-     Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
-                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                             make_stride(params.v_row_stride, _1{}));
+     Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
+                                           + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
+                             make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                             make_stride(params.q_row_stride, params.q_head_stride, _1{}));
+     Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+     Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)
+                                           + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
+                             make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+                             make_stride(params.k_row_stride, params.k_head_stride, _1{}));
+     Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                            make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
+     Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)
+                                           + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
+                             make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+                             make_stride(params.v_row_stride, params.v_head_stride, _1{}));
+     Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                            make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
      Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
                              Shape<Int<kBlockM>, Int<kBlockN>>{},
                              make_stride(params.seqlen_k_rounded, _1{}));
@@ -146,9 +150,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
      Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
      Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
-     Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+     Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K, nblocksN)
      Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
-     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K, nblocksN)
      Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

      typename Kernel_traits::TiledMma tiled_mma;

@@ -241,7 +245,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
      int n_block = n_block_max - 1;
      // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-     flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+     flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
                                         binfo.actual_seqlen_k - n_block * kBlockN);
      cute::cp_async_fence();
      // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
@@ -282,12 +286,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
          // Advance gV
          if (masking_step > 0) {
-             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
          } else {
              // Clear the smem tiles to account for predicated off loads
              flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-                 gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+                 gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
              );
          }
          cute::cp_async_fence();

@@ -305,9 +308,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
          flash::cp_async_wait<0>();
          __syncthreads();
          if (n_block > n_block_min) {
-             // Advance gK
-             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
              // This cp_async_fence needs to be in the if block, otherwise the synchronization
              // isn't right and we get race conditions.
              cute::cp_async_fence();
@@ -355,9 +356,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
          clear(acc_s);
          flash::cp_async_wait<0>();
          __syncthreads();
-         // Advance gV
-         tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-         flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+         flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
          cute::cp_async_fence();

          flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(

@@ -368,9 +367,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
          flash::cp_async_wait<0>();
          __syncthreads();
          if (n_block > n_block_min) {
-             // Advance gK
-             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
              // This cp_async_fence needs to be in the if block, otherwise the synchronization
              // isn't right and we get race conditions.
              cute::cp_async_fence();
@@ -422,14 +419,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
      cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

-     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
-         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-     const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-     Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
-                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                             make_stride(params.o_row_stride, _1{}));
-     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
-                               Shape<Int<kBlockM>>{}, Stride<_1>{});
+     Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+                                           + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
+                             make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
+     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+     Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
+                               make_shape(params.b, params.h, params.seqlen_q),
+                               make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
+     Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));

      typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
      auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -556,8 +555,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
      // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
      // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
-     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
-         + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
      // We move K and V to the last block.
      const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
      const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;

@@ -573,9 +570,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
-     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
-                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                             make_stride(params.q_row_stride, _1{}));
+     Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
+                             make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                             make_stride(params.q_row_stride, params.q_head_stride, _1{}));
+     Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
      Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                              Shape<Int<kBlockN>, Int<kHeadDim>>{},
                              make_stride(params.k_row_stride, _1{}));

@@ -1051,8 +1050,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
      flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
          gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
      );
-     // __syncthreads();
-     // if (cute::thread0()) { print(tOgOaccum); }
  }

  ////////////////////////////////////////////////////////////////////////////////////////////////////
......
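Context for the flash_fwd_kernel.h hunks above: hand-computed byte offsets (`row_offset_q`, `row_offset_k`, ...) and manual pointer advances (`tKgK.data() + (-int(kBlockN * params.k_row_stride))`) are replaced by full global-memory views (`mQ`, `mK`, `mV`) tiled with CuTe's `local_tile`, so each K/V block is selected by index, e.g. `tKgK(_, _, _, n_block)`. A NumPy analogy of that indexing change (not CuTe code; shapes are made up for illustration):

import numpy as np

seqlen_k, h_k, d, kBlockN = 512, 8, 64, 128
K = np.zeros((seqlen_k, h_k, d), dtype=np.float16)  # one batch element of K

# Old style: compute a row offset for the block and slice starting there, advancing
# the offset by -kBlockN rows between iterations.
def k_block_by_offset(n_block, head):
    row = n_block * kBlockN
    return K[row:row + kBlockN, head, :]

# New style (local_tile analogy): view one head's K as (n_blocks, kBlockN, d) and
# pick block n_block by index, like tKgK(_, _, _, n_block) in the kernel.
def k_block_by_tile(n_block, head):
    tiles = K[:, head, :].reshape(seqlen_k // kBlockN, kBlockN, d)
    return tiles[n_block]

# Iterate from the last block down, as the kernel does, and check both views agree.
for n_block in reversed(range(seqlen_k // kBlockN)):
    assert np.array_equal(k_block_by_offset(n_block, 3), k_block_by_tile(n_block, 3))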
@@ -200,6 +200,11 @@ if not SKIP_CUDA_BUILD:
                      # "--ptxas-options=-v",
                      # "--ptxas-options=-O2",
                      # "-lineinfo",
+                     # "-DFLASHATTENTION_DISABLE_BACKWARD",
+                     # "-DFLASHATTENTION_DISABLE_DROPOUT",
+                     # "-DFLASHATTENTION_DISABLE_ALIBI",
+                     # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
+                     # "-DFLASHATTENTION_DISABLE_LOCAL",
                  ]
                  + generator_flag
                  + cc_flag
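The newly listed (commented-out) `FLASHATTENTION_DISABLE_*` defines let a build strip features such as the backward pass, dropout, ALiBi, uneven-K, and local attention to cut compile time and binary size. A hedged sketch of how a downstream build might enable one of them through the standard CUDAExtension mechanism (the extension name and source path here are placeholders, not this repo's actual setup configuration):

from torch.utils.cpp_extension import CUDAExtension

# Placeholder name/sources; in the real setup.py these defines live inside the
# existing nvcc flag list and are commented out by default.
ext = CUDAExtension(
    name="flash_attn_cuda_example",
    sources=["csrc/flash_attn/flash_api.cpp"],
    extra_compile_args={
        "cxx": ["-O3"],
        "nvcc": [
            "-O3",
            "-DFLASHATTENTION_DISABLE_BACKWARD",  # forward-only build (assumption)
            "-DFLASHATTENTION_DISABLE_DROPOUT",
        ],
    },
)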
......
@@ -252,3 +252,40 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of
      assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
      atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
      assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
+
+
+ def test_compilation_count():
+     batch_size = 1
+     headdim = 128
+     device = "cuda"
+     dtype = torch.float16
+
+     torch.manual_seed(42)
+
+     from triton.runtime.jit import JITFunction
+     from flash_attn.ops.triton.rotary import rotary_kernel
+
+     compilation_count = 0
+
+     def count_compilations(*args, **kwargs):
+         nonlocal compilation_count
+         compilation_count += 1
+
+     old_cache_func = JITFunction.cache_hook
+     try:
+         rotary_kernel.cache.clear()
+         JITFunction.cache_hook = count_compilations
+         for seqlen in (128, 256):
+             for nheads in (4, 32):
+                 x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
+                 x.requires_grad_()
+                 cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
+                 out = apply_rotary_emb(x, cos, sin)
+                 out.backward(torch.randn_like(out))
+
+         # Only two kernels are expected to be compiled:
+         # * for the forward pass (conjugate=False)
+         # * for the backward pass (conjugate=True)
+         assert compilation_count == 2
+     finally:
+         JITFunction.cache_hook = old_cache_func
@@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
  RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
  # Install FlashAttention
- RUN pip install flash-attn==2.5.6
+ RUN pip install flash-attn==2.5.7
  # Install CUDA extensions for fused dense
- RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.6#subdirectory=csrc/fused_dense_lib
+ RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.7#subdirectory=csrc/fused_dense_lib
......
- __version__ = "2.5.6"
+ __version__ = "2.5.7"
  from vllm_flash_attn.flash_attn_interface import (
      flash_attn_func,
......