Commit a43fbbf1 authored by Woosuk Kwon
Browse files

Merge remote-tracking branch 'tri/main'

parents 498cd8c3 85881f54
......@@ -44,7 +44,7 @@ jobs:
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240105']
torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240207']
cuda-version: ['11.8.0', '12.2.2']
# We need separate wheels that either use the C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
......@@ -63,7 +63,7 @@ jobs:
python-version: '3.7'
- torch-version: '2.2.0'
python-version: '3.7'
- torch-version: '2.3.0.dev20240105'
- torch-version: '2.3.0.dev20240207'
python-version: '3.7'
# Pytorch <= 2.0 only supports CUDA <= 11.8
- torch-version: '1.12.1'
......
......@@ -205,7 +205,8 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
params.num_splits = num_splits;
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
if (num_splits < 1) {
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
// We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
}
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
......@@ -295,8 +296,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
seqlen_q = ngroups;
num_heads = num_heads_k;
......@@ -323,7 +324,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
if (seqlenq_ngroups_swapped) {
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
}
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else {
out = torch::empty_like(q_padded);
......@@ -494,8 +498,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
max_seqlen_q = ngroups;
num_heads = num_heads_k;
......@@ -550,6 +554,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
if (seqlenq_ngroups_swapped) {
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
}
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else {
out = torch::empty_like(q_padded);
......
......@@ -68,14 +68,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
// Otherwise we might read OOB elements from gK and gV.
if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+ binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
......@@ -108,25 +110,27 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
// We move K and V to the last block.
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
+ binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, params.q_head_stride, _1{}));
Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)
+ binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
make_stride(params.k_row_stride, params.k_head_stride, _1{}));
Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN)
Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)
+ binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
make_stride(params.v_row_stride, params.v_head_stride, _1{}));
Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN)
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.seqlen_k_rounded, _1{}));
......@@ -146,9 +150,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
typename Kernel_traits::TiledMma tiled_mma;
......@@ -241,7 +245,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
......@@ -282,12 +286,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();
......@@ -305,9 +308,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
flash::cp_async_wait<0>();
__syncthreads();
if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
......@@ -355,9 +356,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
......@@ -368,9 +367,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
flash::cp_async_wait<0>();
__syncthreads();
if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
......@@ -422,14 +419,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+ binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
......@@ -556,8 +555,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
// We move K and V to the last block.
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
......@@ -573,9 +570,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, params.q_head_stride, _1{}));
Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
......@@ -1051,8 +1050,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
// __syncthreads();
// if (cute::thread0()) { print(tOgOaccum); }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -200,6 +200,11 @@ if not SKIP_CUDA_BUILD:
# "--ptxas-options=-v",
# "--ptxas-options=-O2",
# "-lineinfo",
# "-DFLASHATTENTION_DISABLE_BACKWARD",
# "-DFLASHATTENTION_DISABLE_DROPOUT",
# "-DFLASHATTENTION_DISABLE_ALIBI",
# "-DFLASHATTENTION_DISABLE_UNEVEN_K",
# "-DFLASHATTENTION_DISABLE_LOCAL",
]
+ generator_flag
+ cc_flag
......
......@@ -252,3 +252,40 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
def test_compilation_count():
    """Check how many times the Triton rotary kernel gets compiled.

    A counting callback is installed as ``JITFunction.cache_hook`` after the
    ``rotary_kernel`` cache has been cleared. ``apply_rotary_emb`` is then run
    forward and backward for every (seqlen, nheads) pair drawn from
    (128, 256) x (4, 32), and the test asserts that the callback fired exactly
    twice. The previously installed hook is restored afterwards, whether or
    not the assertion holds.
    """
    from triton.runtime.jit import JITFunction
    from flash_attn.ops.triton.rotary import rotary_kernel

    batch_size, headdim = 1, 128
    device, dtype = "cuda", torch.float16
    torch.manual_seed(42)

    n_compiled = 0

    def _on_compile(*args, **kwargs):
        # Invoked through JITFunction.cache_hook; only the call count matters.
        nonlocal n_compiled
        n_compiled += 1

    # Same iteration order as nesting the nheads loop inside the seqlen loop.
    shapes = [(seqlen, nheads) for seqlen in (128, 256) for nheads in (4, 32)]

    saved_hook = JITFunction.cache_hook
    try:
        rotary_kernel.cache.clear()
        JITFunction.cache_hook = _on_compile

        for seqlen, nheads in shapes:
            x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
            x.requires_grad_()
            cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
            out = apply_rotary_emb(x, cos, sin)
            out.backward(torch.randn_like(out))

        # Changing seqlen or nheads must not trigger a recompilation; only two
        # kernels are expected to be compiled:
        # * for the forward pass (conjugate=False)
        # * for the backward pass (conjugate=True)
        assert n_compiled == 2
    finally:
        # Always put the original hook back so other tests are unaffected.
        JITFunction.cache_hook = saved_hook
......@@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
# Install FlashAttention
RUN pip install flash-attn==2.5.6
RUN pip install flash-attn==2.5.7
# Install CUDA extensions for fused dense
RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.6#subdirectory=csrc/fused_dense_lib
RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.7#subdirectory=csrc/fused_dense_lib
__version__ = "2.5.6"
__version__ = "2.5.7"
from vllm_flash_attn.flash_attn_interface import (
flash_attn_func,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment