Unverified Commit b36bd0a4 authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

Add FlashAttention3 to CP implementations (#1232)



* fa2 function import renaming
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* refine fa_fwd_kwargs and fa_bwd_kwargs
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* import FA3 fucntions for CP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix output of FA3 fwd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix rng_state in a2a implementation with FA3
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* hack lse correction for packed lse format
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* make CP thd out correction work with packed lse
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix for packed softmax_lse
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix softmax_lse shape
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* change lse_packed to constexpr
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 9ee2dbdd
This diff is collapsed.
......@@ -433,14 +433,14 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
int half_idx);
void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, int total_tokens);
const at::Tensor &cu_seqlens, bool lse_packed);
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
int total_tokens);
bool lse_packed);
void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
bool only_second_half);
bool only_second_half, bool lse_packed);
void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens, const std::string &first_half,
......
......@@ -1464,9 +1464,9 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
template <typename lse_dtype, typename Functor>
template <typename lse_dtype, bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
int num_heads, int max_seqlen) {
int num_heads, int total_tokens) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
......@@ -1480,12 +1480,18 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, half_idx;
if constexpr (lse_packed) {
idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1];
half_idx = head_id * total_tokens / 2 + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
size_t idx = row * max_seqlen + col + seq_len;
size_t half_idx = row * max_seqlen / 2 + col;
idx = row * total_tokens + col + seq_len;
half_idx = row * total_tokens / 2 + col;
}
Functor::run(lse, half_lse, idx, half_idx);
}
......@@ -1504,32 +1510,53 @@ struct LseCorrectionFunctor {
};
void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, int total_tokens) {
const at::Tensor &cu_seqlens, bool lse_packed) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch, num_heads, total_tokens;
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
NVTE_CHECK(lse_per_step.dim() == 2);
batch = cu_seqlens.size(0) - 1;
num_heads = lse.size(0);
total_tokens = lse.size(1);
NVTE_CHECK(lse_per_step.size(0) == num_heads);
NVTE_CHECK(lse_per_step.size(1) == total_tokens / 2);
} else {
NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(lse_per_step.dim() == 3);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch = lse.size(0);
int num_heads = lse.size(1);
int max_seqlen = lse.size(2);
batch = lse.size(0);
num_heads = lse.size(1);
total_tokens = lse.size(2);
NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(2) == max_seqlen / 2);
NVTE_CHECK(lse_per_step.size(2) == total_tokens / 2);
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
}
constexpr unsigned int block = 256;
unsigned int grid_x = (total_tokens / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
thd_lse_kernel<double, LseCorrectionFunctor>
if (lse_packed) {
thd_lse_kernel<double, true, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, max_seqlen);
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, total_tokens);
} else {
thd_lse_kernel<double, false, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, total_tokens);
}
}
struct ReadLseFunctor {
......@@ -1540,29 +1567,51 @@ struct ReadLseFunctor {
};
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
int total_tokens) {
bool lse_packed) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch = lse.size(0);
int num_heads = lse.size(1);
int max_seqlen = lse.size(2);
int batch, num_heads, total_tokens;
std::vector<int64_t> shape;
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
batch = cu_seqlens.size(0) - 1;
num_heads = lse.size(0);
total_tokens = lse.size(1);
shape = {num_heads, total_tokens / 2};
} else {
NVTE_CHECK(lse.dim() == 3);
batch = lse.size(0);
num_heads = lse.size(1);
total_tokens = lse.size(2);
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
std::vector<int64_t> shape = {batch, num_heads, max_seqlen / 2};
shape = {batch, num_heads, total_tokens / 2};
}
at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type()));
constexpr unsigned int block = 256;
unsigned int grid_x = (total_tokens / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
thd_lse_kernel<float, ReadLseFunctor>
if (lse_packed) {
thd_lse_kernel<float, true, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, total_tokens);
} else {
thd_lse_kernel<float, false, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, max_seqlen);
num_heads, total_tokens);
}
return half_lse;
}
......@@ -1571,10 +1620,10 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
* Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/
template <typename dtype, int only_second_half, int tile_size>
template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
float *lse_per_step, int *cu_seqlens, int batch,
int num_heads, int dim_per_head, int max_seqlen) {
int num_heads, int dim_per_head, int lse_seqlen) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1);
......@@ -1592,11 +1641,16 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, idx_per_step;
if constexpr (lse_packed) {
idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
idx = row * max_seqlen + col + seq_len * only_second_half;
idx_per_step = row * max_seqlen / (only_second_half + 1) + col;
idx = row * lse_seqlen + col + seq_len * only_second_half;
idx_per_step = row * lse_seqlen / (only_second_half + 1) + col;
}
float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);
idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
......@@ -1622,7 +1676,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
template <typename dtype, int only_second_half>
static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step,
const at::Tensor &lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens) {
const at::Tensor &cu_seqlens, bool lse_packed) {
NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type());
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
......@@ -1631,17 +1685,30 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
int total_tokens = out.size(0);
int num_heads = out.size(1);
int dim_per_head = out.size(2);
int batch = lse.size(0);
int max_seqlen = lse.size(2);
NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1));
NVTE_CHECK(out_per_step.size(1) == num_heads);
NVTE_CHECK(out_per_step.size(2) == dim_per_head);
int batch, lse_seqlen;
if (lse_packed) {
batch = cu_seqlens.size(0) - 1;
lse_seqlen = total_tokens;
NVTE_CHECK(lse.size(0) == num_heads);
NVTE_CHECK(lse.size(1) == lse_seqlen);
NVTE_CHECK(lse_per_step.size(0) == num_heads);
NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1));
} else {
batch = lse.size(0);
lse_seqlen = lse.size(2);
NVTE_CHECK(lse.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(2) == max_seqlen / (only_second_half + 1));
NVTE_CHECK(lse_per_step.size(2) == lse_seqlen / (only_second_half + 1));
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
}
constexpr int tile = 16;
constexpr int block = 512;
......@@ -1649,39 +1716,53 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
(static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block;
dim3 grid = {grid_x, (unsigned int)num_heads};
thd_out_correction_kernel<dtype, only_second_half, tile>
if (lse_packed) {
thd_out_correction_kernel<dtype, only_second_half, tile, true>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, max_seqlen);
dim_per_head, lse_seqlen);
} else {
thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, lse_seqlen);
}
}
void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
bool only_second_half) {
bool only_second_half, bool lse_packed) {
if (only_second_half) {
if (out.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::Float) {
using dtype = float;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else {
NVTE_ERROR("Unsupported dtype of out\n");
}
} else {
if (out.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::Float) {
using dtype = float;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens);
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else {
NVTE_ERROR("Unsupported dtype of out\n");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment