"vscode:/vscode.git/clone" did not exist on "dcc92d0ab6c4ce022162a23566d44f673251eee4"
Unverified Commit 21ec6e04 authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

change softmax_lse correction of CP to FP32 (#1546)



* fix recompilation of out and lse correction in p2p+bshd/sbhd
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix recompilation of get_seq_chunk_ids_for_reordering
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix recompilation of reorder_seq_chunks_for_a2a
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* recover a change
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* minor change to softmax_lse correction
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* cache cu_seqlens for BSHD/SBHD format
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* do not need to allocate out buffer for BSHD/SBHD
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* code refactoring
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* refactor init out correction
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix a docstring
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* code refactoring
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix init out correct dtype
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add pad_between_seqs to DPA API
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add pad_between_seqs to the API of MHA and transformer layer
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add pad_between_seqs to the API of MHA and transformer layer
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* do not cast partial lse to FP64 for correction
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* do lse correction in FP32 with THD format
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* use log1pf and expf
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0828aa86
......@@ -1247,7 +1247,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if fp8:
out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32)
if i == 1:
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
softmax_lse = torch.clone(softmax_lse_per_step[0])
if qkv_format == "thd":
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
elif (i - 1) <= rank or not causal:
......@@ -1277,7 +1277,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if causal and rank < (cp_size - 1):
second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1]
softmax_lse = softmax_lse.to(torch.float)
for i in range(cp_size):
if i <= rank or not causal:
if qkv_format in ["bshd", "sbhd"]:
......
......@@ -703,7 +703,7 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, bool lse_packed) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double);
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
......@@ -742,14 +742,14 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
transformer_engine::fused_attn::thd_lse_kernel<double, true, LseCorrectionFunctor>
transformer_engine::fused_attn::thd_lse_kernel<true, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
lse.data_ptr<float>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
transformer_engine::fused_attn::thd_lse_kernel<double, false, LseCorrectionFunctor>
transformer_engine::fused_attn::thd_lse_kernel<false, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
lse.data_ptr<float>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
}
}
......@@ -794,12 +794,12 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
transformer_engine::fused_attn::thd_lse_kernel<float, true, ReadLseFunctor>
transformer_engine::fused_attn::thd_lse_kernel<true, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
transformer_engine::fused_attn::thd_lse_kernel<float, false, ReadLseFunctor>
transformer_engine::fused_attn::thd_lse_kernel<false, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, lse_seqlen, second_half_lse_seqlen);
......
......@@ -11,13 +11,13 @@
#include <cuda_bf16.h>
// Functor used by thd_lse_kernel to merge a per-step softmax LSE (log-sum-exp)
// into the running LSE, entirely in FP32.
//
// NOTE(review): this span was extracted from a diff with both the old (double)
// and new (float) lines interleaved; reconstructed here as the post-change FP32
// version, matching the commit intent ("change softmax_lse correction of CP to
// FP32", "use log1pf and expf").
struct LseCorrectionFunctor {
  // Merge half_lse[half_idx] into lse[idx] via the numerically stable
  // log-sum-exp identity:
  //   log(e^a + e^b) = max(a,b) + log1p(exp(min(a,b) - max(a,b)))
  // Subtracting the max keeps the exp argument <= 0, avoiding overflow;
  // log1pf preserves precision when the exp result is near zero.
  __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx,
                                             size_t half_idx) {
    float val = lse[idx];
    float val_per_step = half_lse[half_idx];
    float max_scale = max(val, val_per_step);
    float min_scale = min(val, val_per_step);
    // Single-precision intrinsics (log1pf/expf) — no double-precision math.
    lse[idx] = max_scale + log1pf(expf(min_scale - max_scale));
  }
};
......@@ -148,8 +148,8 @@ __global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_se
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
template <typename lse_dtype, bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
template <bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(float *lse, float *half_lse, int *cu_seqlens, int batch,
int num_heads, int lse_seqlen, int second_half_lse_seqlen) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
......@@ -218,7 +218,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
idx = row * lse_seqlen + col + seq_len * only_second_half;
idx_per_step = row * lse_per_step_seqlen + col;
}
float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);
float lse_corrected_exp = expf(lse_per_step[idx_per_step] - lse[idx]);
idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx = (idx * num_heads + head_id) * dim_per_head;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment