"vscode:/vscode.git/clone" did not exist on "aec1c211f10fae4723da5d73968c0c229ec282a2"
Unverified Commit d8f104e9 authored by rocking, committed by GitHub

Support AMD ROCm on FlashAttention 2 (#1010)



* Support ck in fmha

* Add ck submodule

* Do not return lse if return_softmax == false

* Use receipt to speed up ck compile time

* Integrate new version of ck_tile

* Support dropout for mha_fwd()

* Add dropout to mha_varlen_fwd()

* Update ck to develop

* Extract padding function for dropout randval

* Extract randval transformation function

* Sync the code structure and coding style with FA

* Remove this line; the C++ API will handle this.
Sync with test_flash_attn.py

* fix compile error

* Add mha_bwd

* Generate dropout seed and offset from user generator

* update CK

* Add mha_varlen_bwd

* Use the same Python interpreter as the flash-attn build to generate CK kernels

* Fix bug in group-mode fwd when returning softmax LSE

* Increase the test tolerance

* Add test_flash_attn_output() and test_flash_attn_varlen_output()

* Always fill softmax_lse

* Remove duplicate benchmark script, since we already implement mha_bwd

* Refine getting values from the tuple

* Use default parameter for stream_config

* Unblock all platforms

* Add comment

* refine the test code

* Refine naming

* Add unpack to namespace

* Do not hardcode the warp size 64

* Add more targets

* Add README

* Optimize mha_fwd if seqlen_q == 1

* Support get_wheel_url for rocm

* Detect rocm environment by pytorch's IS_HIP_EXTENSION

* Update to latest CK

* Add necessary compile flag

* Sync the api with upstream FA

---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
parent dfe1a59e
[submodule "csrc/cutlass"] [submodule "csrc/cutlass"]
path = csrc/cutlass path = csrc/cutlass
url = https://github.com/NVIDIA/cutlass.git url = https://github.com/NVIDIA/cutlass.git
[submodule "csrc/composable_kernel"]
path = csrc/composable_kernel
url = https://github.com/ROCm/composable_kernel.git
...@@ -434,6 +434,33 @@ This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.
If you encounter bugs, please open a GitHub Issue!
## AMD GPU/ROCm Support
The ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as its backend, which provides the implementation of FlashAttention-2.
## Installation and features
Requirements:
- ROCm 6.0+
- PyTorch 1.12.1+
We recommend the
[PyTorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.
To compile from source:
```sh
python setup.py install
```
FlashAttention-2 on ROCm currently supports:
1. MI200 and MI300 series GPUs.
2. The fp16 and bf16 data types.
3. Head dimensions up to 256 in the forward pass and up to 128 in the backward pass.
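The Python interface is the same as for the CUDA build. As a minimal sketch (not taken from the upstream docs; shapes and sizes are arbitrary), a forward call on a ROCm device looks like the following; note that HIP devices are still exposed as `"cuda"` devices in PyTorch:
```python
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim), fp16 or bf16 as listed above
q = torch.randn(2, 1024, 16, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 16, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 16, 64, dtype=torch.float16, device="cuda")

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([2, 1024, 16, 64])
```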
## Tests
To run the tests:
```sh
pytest tests/test_flash_attn_ck.py
```
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
...
Subproject commit 8182976c37433808b5e3a27a6536d1b74b0c23a1
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash_common.hpp"
std::vector<at::Tensor>
mha_fwd(at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
c10::optional<at::Tensor> &out_,
c10::optional<at::Tensor> &alibi_slopes_,
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<const at::Tensor> &leftpad_k_, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads x head_size_og
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
}
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
namespace flash {
// Copy from PyTorch
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
if (arg.captured_) {
// static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
// *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
// For most threads' reads it will hit in cache, so it shouldn't hurt performance.
return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
} else {
return std::make_tuple(arg.seed_.val, arg.offset_.val);
}
}
} // namespace flash
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash_common.hpp"
#include "fmha_bwd.hpp"
#include "mask.hpp"
fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool enable_alibi)
{
return fmha_bwd_traits{head_size,
head_size,
dtype,
false, // is_group_mode
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
false, // has_dbias
has_dropout};
}
fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
// sizes
const int b,
const int seqlen_q,
const int seqlen_k,
const int h,
const int h_k,
const int hdim,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
c10::optional<at::Tensor> &alibi_slopes_,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,
at::Tensor d,
at::Tensor dq,
at::Tensor dk,
at::Tensor dv,
float softmax_scale,
float p_dropout,
uint64_t drop_seed,
uint64_t drop_offset)
{
// q: (batch_size, seqlen_q, nheads, hdim)
// k: (batch_size, seqlen_k, nheads_k, hdim)
// v: (batch_size, seqlen_k, nheads_k, hdim)
// o: (batch_size, seqlen_q, nheads, hdim)
// dq: (batch_size, seqlen_q, nheads, hdim)
// dk_expanded: (batch_size, seqlen_k, nheads, hdim)
// dv_expanded: (batch_size, seqlen_k, nheads, hdim)
// do: (batch_size, seqlen_q, nheads, hdim)
// alibi_slopes:(batch_size, nheads) or (nhead)
// lse: (batch_size, nheads, seqlen_q)
// d: (batch_size, nheads, seqlen_q)
ck_tile::index_t stride_q = q.stride(1);
ck_tile::index_t stride_k = k.stride(1);
ck_tile::index_t stride_v = v.stride(1);
ck_tile::index_t stride_o = out.stride(1);
ck_tile::index_t stride_do = dout.stride(1);
ck_tile::index_t stride_dk = dk.stride(1);
ck_tile::index_t stride_dv = dv.stride(1);
ck_tile::index_t nhead_stride_q = q.stride(2);
ck_tile::index_t nhead_stride_k = k.stride(2);
ck_tile::index_t nhead_stride_v = v.stride(2);
ck_tile::index_t nhead_stride_o = out.stride(2);
ck_tile::index_t nhead_stride_do = dout.stride(2);
ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
ck_tile::index_t batch_stride_q = q.stride(0);
ck_tile::index_t batch_stride_k = k.stride(0);
ck_tile::index_t batch_stride_v = v.stride(0);
ck_tile::index_t batch_stride_o = out.stride(0);
ck_tile::index_t batch_stride_do = dout.stride(0);
ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
ck_tile::index_t batch_stride_dk = dk.stride(0);
ck_tile::index_t batch_stride_dv = dv.stride(0);
float p_undrop = 1.0 - p_dropout;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}
return fmha_bwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
out.data_ptr(),
softmax_lse.data_ptr(),
dout.data_ptr(),
d.data_ptr(),
nullptr, // rand_val
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqlen_k_ptr
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
hdim, // hdim_q
hdim, // hdim_v
h, // nhead
h_k, // nhead_k
softmax_scale,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_o,
0, // stride_randval
stride_do,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_o,
0, // nhead_stride_randval
nhead_stride_do,
nhead_stride_lse,
0, // nhead_stride_dbias, FA without dbias
batch_stride_q,
batch_stride_k,
batch_stride_v,
0 , // batch_stride_bias, FA without bias
batch_stride_o,
0, // batch_stride_randval
batch_stride_do,
batch_stride_lse,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
p_undrop,
false, // s_randval
{drop_seed, drop_offset}};
}
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads x head_size_og
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
const float /*softcap*/,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state)
{
#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
if (is_causal) { window_size_right = 0; }
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentHIPStream().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q = sizes[1];
const int num_heads = sizes[2];
const int head_size_og = dout.size(3); // unpadded hdim
const int head_size_8x = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
mask_info mask;
if (is_causal) {
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
}
else if (window_size_left == -1 && window_size_right == -1) {
mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
}
else {
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
}
// q, k, v, out have been padded in mha_fwd
// dq_, dk_, dv_ are padded tensors as well
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_8x);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_8x);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_8x);
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_8x);
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
at::Tensor dq, dk, dv;
if (dq_.has_value()) {
dq = dq_.value();
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
CHECK_DEVICE(dq);
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size_8x);
} else {
dq = torch::empty_like(q);
}
if (dk_.has_value()) {
dk = dk_.value();
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
CHECK_DEVICE(dk);
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size_8x);
} else {
dk = torch::empty_like(k);
}
if (dv_.has_value()) {
dv = dv_.value();
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
CHECK_DEVICE(dv);
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_8x);
} else {
dv = torch::empty_like(v);
}
at::Tensor dout_padded;
if (head_size_og % 8 != 0) {
dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
dout_padded = dout;
}
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
// TODO - CK does not support dq_accum
at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
} else {
dk_expanded = dk;
dv_expanded = dv;
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
uint64_t drop_seed = 1, drop_offset = 0;
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
if (rng_state.has_value()) {
uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
drop_seed = d[0];
drop_offset = d[1];
} else if(is_dropout) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
auto philox_args = gen->philox_cuda_state(counter_offset);
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
}
if (seqlen_q > 0) {
ck_tile::stream_config stream_config{stream};
dq.zero_(); // CK uses atomic operations on dq
auto traits =
get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
auto args =
get_ck_fmha_bwd_args(
mask,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size_8x,
q,
k,
v,
alibi_slopes_,
out,
softmax_lse,
dout_padded,
softmax_d,
dq,
dk_expanded,
dv_expanded,
softmax_scale,
p_dropout,
drop_seed,
drop_offset);
fmha_bwd(traits, args, stream_config);
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}
// For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) {
at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
}
if (head_size_og % 8 != 0) {
dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
}
return { dq, dk, dv, softmax_d };
}
\ No newline at end of file
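An aside on the mask handling in mha_bwd above (the forward path below uses the same branches): the attention window is encoded as a short string that mask_info::decode parses, where causal masking becomes "b:<left>,0", an unrestricted window becomes "0", and anything else is a local (sliding-window) mask "b:<left>,<right>". A purely illustrative Python restatement of that decision tree (not part of the commit):
```python
def mask_string(is_causal: bool, window_size_left: int, window_size_right: int) -> str:
    """Mirror the mask-selection branches in mha_fwd/mha_bwd above."""
    if is_causal:
        # Causal is the special case: the right window is forced to 0.
        return f"b:{window_size_left},0"
    if window_size_left == -1 and window_size_right == -1:
        return "0"  # no mask
    # Local / sliding-window mask.
    return f"b:{window_size_left},{window_size_right}"

assert mask_string(True, -1, 0) == "b:-1,0"    # causal
assert mask_string(False, -1, -1) == "0"       # no mask
assert mask_string(False, 128, 0) == "b:128,0" # local window
```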
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash_common.hpp"
#include "fmha_fwd.hpp"
#include "mask.hpp"
fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool has_lse,
bool enable_alibi)
{
return fmha_fwd_traits{head_size,
head_size,
dtype,
false, // is_group_mode
true, // is_v_rowmajor
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
has_lse,
has_dropout,
false}; // do_fp8_static_quant
}
fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
bool has_dropout_randval,
const mask_info &mask,
// sizes
const int b,
const int seqlen_q,
const int seqlen_k,
const int h,
const int h_k,
const int d,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
c10::optional<at::Tensor> &alibi_slopes_,
at::Tensor out,
at::Tensor softmax_lse,
at::Tensor dropout_randval,
float softmax_scale,
float p_dropout,
uint64_t drop_seed,
uint64_t drop_offset)
{
// q: (batch_size, seqlen_q, nheads, d)
// k: (batch_size, seqlen_k, nheads_k, d)
// v: (batch_size, seqlen_k, nheads_k, d)
// o: (batch_size, seqlen_q, nheads, d)
// alibi_slopes:(batch_size, nheads) or (nhead)
// lse: (batch_size, nheads, seqlen_q)
// randval: (batch_size, nheads, seqlen_q, seqlen_k)
ck_tile::index_t stride_q = q.stride(1);
ck_tile::index_t stride_k = k.stride(1);
ck_tile::index_t stride_v = v.stride(1);
ck_tile::index_t stride_o = out.stride(1);
ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(2) : 0;
ck_tile::index_t nhead_stride_q = q.stride(2);
ck_tile::index_t nhead_stride_k = k.stride(2);
ck_tile::index_t nhead_stride_v = v.stride(2);
ck_tile::index_t nhead_stride_o = out.stride(2);
ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;
ck_tile::index_t batch_stride_q = q.stride(0);
ck_tile::index_t batch_stride_k = k.stride(0);
ck_tile::index_t batch_stride_v = v.stride(0);
ck_tile::index_t batch_stride_o = out.stride(0);
ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}
return fmha_fwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
nullptr, // lse_acc
nullptr, // o_acc
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqlen_k_ptr
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
d, // hdim_q
d, // hdim_v
h, // nhead
h_k, // nhead_k
1, // num_splits
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_randval,
0, // stride_o_acc,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_randval,
nhead_stride_lse,
0, // nhead_stride_lse_acc
0, // nhead_stride_o_acc
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias, FA without bias
batch_stride_randval,
batch_stride_lse,
0, // batch_stride_lse_acc
0, // batch_stride_o_acc
batch_stride_o,
0, // split_stride_lse_acc
0, // split_stride_o_acc
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
has_dropout_randval,
{drop_seed, drop_offset}};
}
std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float /*softcap*/,
const bool return_dropout_randval,
c10::optional<at::Generator> gen_)
{
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
int seqlen_q = sizes[1];
int num_heads = sizes[2];
const int head_size_og = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
// causal=true is the same as causal=false in this case
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
mask_info mask;
if (is_causal) {
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
window_size_right = 0;
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
}
else if (window_size_left == -1 && window_size_right == -1) {
mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
}
else {
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
}
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
seqlen_q = ngroups;
num_heads = num_heads_k;
}
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
at::Tensor q_padded, k_padded, v_padded;
if (head_size_og % 8 != 0) {
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
}
else {
q_padded = q;
k_padded = k;
v_padded = v;
}
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
if (seqlenq_ngroups_swapped) {
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
}
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
}
else {
out = torch::empty_like(q_padded);
}
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_8x = round_multiple(head_size_og, 8);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
bool has_lse = true;
bool has_dropout = p_dropout > 0.0f;
at::Tensor softmax_lse;
// TODO - check gradient, only training requires lse
softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(torch::kFloat32));
at::Tensor p;
if (return_dropout_randval) {
TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
p = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(torch::kUInt8));
}
uint64_t drop_seed = 1, drop_offset = 0;
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
if (p_dropout > 0.0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
auto philox_args = gen->philox_cuda_state(counter_offset);
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
}
rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
if (seqlen_k > 0) {
auto stream = at::cuda::getCurrentHIPStream().stream();
ck_tile::stream_config stream_config{stream};
auto traits =
get_ck_fmha_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
auto args =
get_ck_fmha_fwd_args(
has_lse,
return_dropout_randval,
mask,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size_8x,
q_padded,
k_padded,
v_padded,
alibi_slopes_,
out,
softmax_lse,
p,
softmax_scale,
p_dropout,
drop_seed,
drop_offset);
fmha_fwd(traits, args, stream_config);
}
else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}
at::Tensor out_padded = out;
if (head_size_og % 8 != 0) {
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
if (out_.has_value()) { out_.value().copy_(out); }
}
if (seqlenq_ngroups_swapped) {
out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}
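One detail of mha_fwd above worth calling out: the Philox dropout state is a pair of uint64_t values (drop_seed, drop_offset), but it is handed back to the caller in an int64 tensor (rng_state) by reinterpreting the bits, and mha_bwd reinterprets them back to uint64_t. A small Python sketch (illustration only, not part of the commit) showing why that round-trip is lossless:
```python
import struct

def uint64_to_int64(x: int) -> int:
    """Reinterpret the bits of a uint64 as an int64 (what the reinterpret_cast does)."""
    return struct.unpack("<q", struct.pack("<Q", x))[0]

def int64_to_uint64(x: int) -> int:
    """Inverse reinterpretation, as done when the backward reads rng_state."""
    return struct.unpack("<Q", struct.pack("<q", x))[0]

seed = 0xDEADBEEF12345678               # an arbitrary Philox seed
stored = uint64_to_int64(seed)          # value that ends up in the int64 rng_state tensor
assert int64_to_uint64(stored) == seed  # the backward recovers the original seed exactly
```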
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash_common.hpp"
#include "fmha_bwd.hpp"
#include "mask.hpp"
fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool enable_alibi)
{
return fmha_bwd_traits{head_size,
head_size,
dtype,
true, // is_group_mode
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
false, // has_dbias
has_dropout};
}
fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
// sizes
const int b,
const int max_seqlen_q,
const int max_seqlen_k,
const int h,
const int h_k,
const int hdim,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
const at::Tensor seqlens_q,
const at::Tensor seqlens_k,
c10::optional<at::Tensor> &alibi_slopes_,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,
at::Tensor d,
at::Tensor dq,
at::Tensor dk,
at::Tensor dv,
float softmax_scale,
float p_dropout,
uint64_t drop_seed,
uint64_t drop_offset)
{
// q: (total_q, nheads, hdim)
// k: (total_k, nheads_k, hdim)
// v: (total_k, nheads_k, hdim)
// o: (total_q, nheads, hdim)
// dq: (total_q, nheads, hdim)
// dk_expanded: (total_k, nheads, hdim)
// dv_expanded: (total_k, nheads, hdim)
// do: (total_q, nheads, hdim)
// alibi_slopes:(batch_size, nheads) or (nhead)
// lse: (batch_size, nheads, max_seqlen_q)
// d: (batch_size, nheads, max_seqlen_q)
ck_tile::index_t total_q = q.size(0);
ck_tile::index_t total_k = k.size(0);
ck_tile::index_t stride_q = q.stride(0);
ck_tile::index_t stride_k = k.stride(0);
ck_tile::index_t stride_v = v.stride(0);
ck_tile::index_t stride_o = out.stride(0);
ck_tile::index_t stride_do = dout.stride(0);
ck_tile::index_t stride_dk = dk.stride(0);
ck_tile::index_t stride_dv = dv.stride(0);
ck_tile::index_t nhead_stride_q = q.stride(1);
ck_tile::index_t nhead_stride_k = k.stride(1);
ck_tile::index_t nhead_stride_v = v.stride(1);
ck_tile::index_t nhead_stride_o = out.stride(1);
ck_tile::index_t nhead_stride_do = dout.stride(1);
ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
ck_tile::index_t batch_stride_q = 0;
ck_tile::index_t batch_stride_k = 0;
ck_tile::index_t batch_stride_v = 0;
ck_tile::index_t batch_stride_o = 0;
ck_tile::index_t batch_stride_do = 0;
ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
ck_tile::index_t batch_stride_dk = 0;
ck_tile::index_t batch_stride_dv = 0;
float p_undrop = 1.0 - p_dropout;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}
return fmha_bwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
out.data_ptr(),
softmax_lse.data_ptr(),
dout.data_ptr(),
d.data_ptr(),
nullptr, // rand_val
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
seqlens_q.data_ptr(), // seqstart_q
seqlens_k.data_ptr(), // seqstart_k
nullptr, // seqlen_k_ptr
total_q,
total_k,
b,
max_seqlen_q, // max_seqlen_q
max_seqlen_k, // max_seqlen_k
hdim, // hdim_q
hdim, // hdim_v
h, // nhead
h_k, // nhead_k
softmax_scale,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_o,
0, // stride_randval
stride_do,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_o,
0, // nhead_stride_randval
nhead_stride_do,
nhead_stride_lse,
0, // nhead_stride_dbias, FA without dbias
batch_stride_q,
batch_stride_k,
batch_stride_v,
0 , // batch_stride_bias, FA without bias
batch_stride_o,
0, // batch_stride_randval
batch_stride_do,
batch_stride_lse,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
p_undrop,
false, // s_randval
{drop_seed, drop_offset}};
}
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
const float /*softcap*/,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state)
{
#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
if (is_causal) { window_size_right = 0; }
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);
const auto sizes = q.sizes();
const int total_q = sizes[0];
const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = sizes[1];
const int head_size_og = dout.size(2);
const int head_size_8x = sizes[2];
const int total_k = k.size(0);
const int num_heads_k = k.size(1);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_8x % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
mask_info mask;
if (is_causal) {
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal
}
else if (window_size_left == -1 && window_size_right == -1) {
mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
}
else {
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
}
// q, k, v, out have been padded in mha_fwd
// dq_, dk_, dv_ are padded tensors as well
CHECK_SHAPE(q, total_q, num_heads, head_size_8x);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_8x);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_8x);
CHECK_SHAPE(out, total_q, num_heads, head_size_8x);
CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
at::Tensor dq, dk, dv;
if (dq_.has_value()) {
dq = dq_.value();
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
CHECK_DEVICE(dq);
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
CHECK_SHAPE(dq, total_q, num_heads, head_size_8x);
} else {
dq = torch::empty_like(q);
}
if (dk_.has_value()) {
dk = dk_.value();
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
CHECK_DEVICE(dk);
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
CHECK_SHAPE(dk, total_k, num_heads_k, head_size_8x);
} else {
dk = torch::empty_like(k);
}
if (dv_.has_value()) {
dv = dv_.value();
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
CHECK_DEVICE(dv);
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, total_k, num_heads_k, head_size_8x);
} else {
dv = torch::empty_like(v);
}
at::Tensor dout_padded;
if (head_size_og % 8 != 0) {
dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
dout_padded = dout;
}
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
// TODO - CK does not support dq_accum
at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
dk_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
dv_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
} else {
dk_expanded = dk;
dv_expanded = dv;
}
if(zero_tensors) {
dq.zero_();
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
uint64_t drop_seed = 1, drop_offset = 0;
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
if (rng_state.has_value()) {
uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
drop_seed = d[0];
drop_offset = d[1];
} else if(is_dropout) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
auto philox_args = gen->philox_cuda_state(counter_offset);
std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
}
if (max_seqlen_q > 0) {
ck_tile::stream_config stream_config{stream};
dq.zero_(); // CK uses atomic operations on dq
auto traits =
get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
auto args =
get_ck_fmha_varlen_bwd_args(
mask,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size_8x,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes_,
out,
softmax_lse,
dout_padded,
softmax_d,
dq,
dk_expanded,
dv_expanded,
softmax_scale,
p_dropout,
drop_seed,
drop_offset);
fmha_bwd(traits, args, stream_config);
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}
// For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) {
at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
}
if (head_size_og % 8 != 0) {
dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
}
return { dq, dk, dv, softmax_d };
}
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash_common.hpp"
#include "fmha_fwd.hpp"
#include "mask.hpp"
fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool has_lse,
bool enable_alibi)
{
return fmha_fwd_traits{head_size,
head_size,
dtype,
true, // is_group_mode
true, // is_v_rowmajor
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
has_lse,
has_dropout,
false}; // do_fp8_static_quant
}
fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
bool has_dropout_randval,
const mask_info &mask,
// sizes
const int b,
const int max_seqlen_q,
const int h,
const int h_k,
const int d,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
const at::Tensor seqlens_q,
const at::Tensor seqlens_k,
c10::optional<at::Tensor> &alibi_slopes_,
at::Tensor out,
at::Tensor softmax_lse,
at::Tensor dropout_randval,
float softmax_scale,
float p_dropout,
uint64_t drop_seed,
uint64_t drop_offset)
{
// q: (total_q, nheads, d)
// k: (total_k, nheads_k, d)
// v: (total_k, nheads_k, d)
// o: (total_q, nheads, d)
// alibi_slopes:(batch, nheads) or (nhead)
// lse: (batch, nheads, max_seqlen_q)
// randval: (nheads, total_q, max_seqlen_k)
ck_tile::index_t total_q = q.size(0);
ck_tile::index_t total_k = k.size(0);
ck_tile::index_t stride_q = q.stride(0);
ck_tile::index_t stride_k = k.stride(0);
ck_tile::index_t stride_v = v.stride(0);
ck_tile::index_t stride_o = out.stride(0);
ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;
ck_tile::index_t nhead_stride_q = q.stride(1);
ck_tile::index_t nhead_stride_k = k.stride(1);
ck_tile::index_t nhead_stride_v = v.stride(1);
ck_tile::index_t nhead_stride_o = out.stride(1);
ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
ck_tile::index_t batch_stride_q = 0;
ck_tile::index_t batch_stride_k = 0;
ck_tile::index_t batch_stride_v = 0;
ck_tile::index_t batch_stride_o = 0;
ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
ck_tile::index_t batch_stride_randval = 0;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}
return fmha_fwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
nullptr, // lse_acc
nullptr, // o_acc
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
seqlens_q.data_ptr(), // seqstart_q
seqlens_k.data_ptr(), // seqstart_k
nullptr, // seqlen_kpads
total_q,
total_k,
b,
max_seqlen_q,
d, // hdim_q
d, // hdim_v
h, // nhead
h_k, // nhead_k
1, // num_splits
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_randval,
0, // stride_o_acc,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_randval,
nhead_stride_lse,
0, // nhead_stride_lse_acc
0, // nhead_stride_o_acc
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias, FA without bias
batch_stride_randval,
batch_stride_lse,
0, // batch_stride_lse_acc
0, // batch_stride_o_acc
batch_stride_o,
0, // split_stride_lse_acc
0, // split_stride_o_acc
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
has_dropout_randval,
{drop_seed, drop_offset}};
}
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> & /*seqused_k*/,
c10::optional<const at::Tensor> &/*leftpad_k_*/, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const float /*softcap*/,
const bool return_dropout_randval,
c10::optional<at::Generator> gen_)
{
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(cu_seqlens_q);
CHECK_DEVICE(cu_seqlens_k);
// TODO - Support paged_KV
const bool paged_KV = block_table_.has_value();
TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet");
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);
const auto sizes = q.sizes();
const int batch_size = cu_seqlens_q.numel() - 1;
int num_heads = sizes[1];
const int head_size_og = sizes[2];
const int num_heads_k = k.size(1);
const int max_num_blocks_per_seq = 0;
const int num_blocks = 0;
if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
// TODO
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int total_q = q.size(0);
const int total_k = k.size(0);
TORCH_CHECK(batch_size > 0, "batch size must be postive");
TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
mask_info mask;
if (is_causal) {
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
window_size_right = 0;
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal
}
else if (window_size_left == -1 && window_size_right == -1) {
mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
}
else {
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
}
CHECK_SHAPE(q, total_q, num_heads, head_size_og);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
at::Tensor q_padded, k_padded, v_padded;
if (head_size_og % 8 != 0) {
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
}
else {
q_padded = q;
k_padded = k;
v_padded = v;
}
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
}
else {
out = torch::empty_like(q_padded);
}
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_8x = round_multiple(head_size_og, 8);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
bool has_lse = true;
bool has_dropout = p_dropout > 0.0f;
at::Tensor softmax_lse;
// TODO - check gradient, only training requires lse
softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(torch::kFloat32));
at::Tensor p;
if (return_dropout_randval) {
TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8));
}
if (zero_tensors)
{
out.zero_();
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
if (return_dropout_randval) {p.zero_();}
}
    uint64_t drop_seed = 1, drop_offset = 0;
    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
    if (p_dropout > 0.0) {
        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
            gen_, at::cuda::detail::getDefaultCUDAGenerator());
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        auto philox_args = gen->philox_cuda_state(counter_offset);
        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
    }
    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
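    // Sketch of the intent (not in the original diff): the Philox seed/offset pair
    // drawn from the user-visible generator is bit-cast into the two int64 slots of
    // rng_state, presumably so the same dropout pattern can be replayed later
    // (e.g. by the backward pass); drop_seed is stored verbatim in rng_state[0]
    // and drop_offset in rng_state[1].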
    if (max_seqlen_k > 0) {
        auto stream = at::cuda::getCurrentHIPStream().stream();
        ck_tile::stream_config stream_config{stream};
        auto traits =
            get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
        auto args =
            get_ck_fmha_varlen_fwd_args(
                has_lse,
                return_dropout_randval,
                mask,
                batch_size,
                max_seqlen_q,
                num_heads,
                num_heads_k,
                head_size_8x,
                q_padded,
                k_padded,
                v_padded,
                cu_seqlens_q,
                cu_seqlens_k,
                alibi_slopes_,
                out,
                softmax_lse,
                p,
                softmax_scale,
                p_dropout,
                drop_seed,
                drop_offset);
        fmha_fwd(traits, args, stream_config);
    }
    else {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }
    at::Tensor out_padded = out;
    if (head_size_og % 8 != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
    }
    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}
@@ -5,6 +5,8 @@ import warnings
import os
import re
import ast
import glob
import shutil
from pathlib import Path
from packaging.version import parse, Version
import platform
@@ -22,6 +24,8 @@ from torch.utils.cpp_extension import (
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
    ROCM_HOME,
    IS_HIP_EXTENSION,
)
@@ -32,6 +36,19 @@ with open("README.md", "r", encoding="utf-8") as fh:
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")

if BUILD_TARGET == "auto":
    if IS_HIP_EXTENSION:
        IS_ROCM = True
    else:
        IS_ROCM = False
else:
    if BUILD_TARGET == "cuda":
        IS_ROCM = False
    elif BUILD_TARGET == "rocm":
        IS_ROCM = True
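# Illustrative usage (not part of the diff): the auto-detection above can be
# overridden from the environment, e.g.
#   BUILD_TARGET=rocm python setup.py install   # force the ROCm/composable_kernel build
#   BUILD_TARGET=cuda python setup.py install   # force the CUDA/cutlass build
# With BUILD_TARGET=auto, IS_HIP_EXTENSION from torch.utils.cpp_extension decides.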
PACKAGE_NAME = "flash_attn"

BASE_WHEEL_URL = (
@@ -82,19 +99,47 @@ def check_if_cuda_home_none(global_option: str) -> None:
    )
def check_if_rocm_home_none(global_option: str) -> None:
    if ROCM_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but hipcc was not found."
    )
def append_nvcc_threads(nvcc_extra_args):
    nvcc_threads = os.getenv("NVCC_THREADS") or "4"
    return nvcc_extra_args + ["--threads", nvcc_threads]
def rename_cpp_to_cu(cpp_files):
    for entry in cpp_files:
        shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")
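# Illustrative note (not part of the diff): torch's CUDAExtension only treats *.cu
# files as device sources, so the handwritten and generated *.cpp kernels are copied
# to *.cu before being handed to hipcc. For example (hypothetical file name):
#   rename_cpp_to_cu(["build/fmha_fwd_d128_fp16.cpp"])
#   # -> also creates build/fmha_fwd_d128_fp16.cu, which is what gets compiled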
def validate_and_update_archs(archs):
    # List of allowed architectures
    allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]

    # Validate if each element in archs is in allowed_archs
    assert all(
        arch in allowed_archs for arch in archs
    ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention"
cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
if IS_ROCM:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
else:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])

if not SKIP_CUDA_BUILD and not IS_ROCM:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -250,6 +295,95 @@ if not SKIP_CUDA_BUILD:
            ],
        )
    )
elif not SKIP_CUDA_BUILD and IS_ROCM:
    ck_dir = "csrc/composable_kernel"

    # Use the composable_kernel codegen script to generate the fwd/bwd kernel dispatch sources
    if not os.path.exists("./build"):
        os.makedirs("build")

    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
check_if_rocm_home_none("flash_attn")
cc_flag = []
archs = os.getenv("GPU_ARCHS", "native").split(";")
validate_and_update_archs(archs)
cc_flag = [f"--offload-arch={arch}" for arch in archs]
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
if FORCE_CXX11_ABI:
torch._C._GLIBCXX_USE_CXX11_ABI = True
sources = ["csrc/flash_attn_ck/flash_api.cpp",
"csrc/flash_attn_ck/mha_bwd.cpp",
"csrc/flash_attn_ck/mha_fwd.cpp",
"csrc/flash_attn_ck/mha_varlen_bwd.cpp",
"csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
f"build/fmha_*wd*.cpp"
)
rename_cpp_to_cu(sources)
renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
"csrc/flash_attn_ck/mha_bwd.cu",
"csrc/flash_attn_ck/mha_fwd.cu",
"csrc/flash_attn_ck/mha_varlen_bwd.cu",
"csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
    extra_compile_args = {
        "cxx": ["-O3", "-std=c++17"] + generator_flag,
        "nvcc":
            [
                "-O3", "-std=c++17",
                "-mllvm", "-enable-post-misched=0",
                "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
                "-fgpu-flush-denormals-to-zero",
                "-DCK_ENABLE_BF16",
                "-DCK_ENABLE_BF8",
                "-DCK_ENABLE_FP16",
                "-DCK_ENABLE_FP32",
                "-DCK_ENABLE_FP64",
                "-DCK_ENABLE_FP8",
                "-DCK_ENABLE_INT8",
                "-DCK_USE_XDL",
                "-DUSE_PROF_API=1",
                "-D__HIP_PLATFORM_HCC__=1",
                # "-DFLASHATTENTION_DISABLE_BACKWARD",
            ]
            + generator_flag
            + cc_flag,
    }
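    # Note (assumption for illustration): when building a HIP extension, torch's build
    # machinery routes the "nvcc" key of extra_compile_args to hipcc, so the flags
    # above are effectively hipcc/clang options, roughly e.g.
    #   hipcc -O3 -std=c++17 --offload-arch=gfx90a -DCK_TILE_FMHA_FWD_FAST_EXP2=1 ... fmha_fwd_*.cu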
    include_dirs = [
        Path(this_dir) / "csrc" / "composable_kernel" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
    ]

    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=renamed_sources,
            extra_compile_args=extra_compile_args,
            include_dirs=include_dirs,
        )
    )
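    # For context (not part of the diff): the module keeps the name "flash_attn_2_cuda"
    # on ROCm as well, so the Python package can import the same binding either way, e.g.
    #   import flash_attn_2_cuda as flash_attn_gpu   # alias name is illustrative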
def get_package_version():
@@ -264,25 +398,33 @@ def get_package_version():

def get_wheel_url():
    torch_version_raw = parse(torch.__version__)
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    flash_version = get_package_version()
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    if IS_ROCM:
        torch_hip_version = parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
        hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    else:
        # Determine the version numbers that will be used to determine the correct wheel
        # We're using the CUDA version used to build torch, not the one currently installed
        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
        torch_cuda_version = parse(torch.version.cuda)
        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
        # to save CI time. Minor versions should be compatible.
        torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
        # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
        # Determine wheel URL based on CUDA version, torch version, python version and OS
        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
    return wheel_url, wheel_filename
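# Worked example (hypothetical values, not from the diff): on a ROCm 6.0 build of
# torch 2.3 with CPython 3.10 on linux_x86_64 and a package version of 2.5.8, the
# code above would look for a wheel named
#   flash_attn-2.5.8+rocm60torch2.3cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
# (cxx11abiFALSE if torch was built with the old ABI) under the release tag v2.5.8.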