Commit b65e50ba authored by yuguo

[DCU] fix merge

parent f8c2af4c
@@ -189,7 +189,7 @@ __global__ void thd_lse_kernel(float *lse, float *half_lse, int *cu_seqlens, int
 **************************************************************************************************/
 template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
-__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
+__global__ void __launch_bounds__(512) thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
                                           float *lse_per_step, int *cu_seqlens, int batch,
                                           int num_heads, int dim_per_head, int lse_seqlen,
                                           int lse_per_step_seqlen) {
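A note on the change above: `__launch_bounds__(512)` promises the compiler that this kernel is never launched with more than 512 threads per block, letting it budget registers against that bound instead of the architectural maximum. A minimal, self-contained sketch of the qualifier (hypothetical `scale_kernel`, not part of this patch):

```cpp
#include <cuda_runtime.h>

// The qualifier sits between __global__ and the kernel name; launching with
// more than 512 threads per block would then be a launch error.
__global__ void __launch_bounds__(512) scale_kernel(float *out, const float *in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * in[i];
}

void launch_scale(float *out, const float *in, int n) {
  int threads = 512;  // must not exceed the __launch_bounds__ limit
  int blocks = (n + threads - 1) / threads;
  scale_kernel<<<blocks, threads>>>(out, in, n);
}
```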
......
@@ -57,7 +57,7 @@ reduce_block_into_lanes(T *x, T val, int lanes = 1,
     // __SYNCWARP();
 #ifdef __HIP_PLATFORM_AMD__
 #pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i);
+    for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i, THREADS_PER_WARP);
 #else
 #pragma unroll
     for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
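Both AMD-side edits in this file add an explicit third argument to `__shfl_down`. On HIP that parameter is the logical warp width; passing `THREADS_PER_WARP` pins the width the reduction operates over rather than leaving it to the default, which matters because AMD wavefronts are 64 lanes wide while the CUDA branch assumes 32. A self-contained sketch of the same tree sum, assuming `THREADS_PER_WARP` stands in for the project's warp-size constant:

```cpp
#ifndef THREADS_PER_WARP
#define THREADS_PER_WARP 32  // assumption: stand-in for the project's constant
#endif

__device__ float warp_reduce_sum(float v) {
  // Each step folds the upper half of the logical warp onto the lower half;
  // after log2(width) steps, lane 0 holds the full sum.
#pragma unroll
  for (int i = THREADS_PER_WARP / 2; i > 0; i >>= 1)
#ifdef __HIP_PLATFORM_AMD__
    v += __shfl_down(v, i, THREADS_PER_WARP);  // width passed explicitly on HIP
#else
    v += __shfl_down_sync(0xffffffff, v, i);   // CUDA variant takes a lane mask
#endif
  return v;
}
```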
@@ -104,7 +104,7 @@ reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
 #pragma unroll
     for (int i = 16; i >= lanes; i >>= 1)
 #ifdef __HIP_PLATFORM_AMD__
-      final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i)));
+      final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i, THREADS_PER_WARP)));
 #else
       final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
 #endif
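The same width fix, applied to the max-abs variant. Wrapping both operands in `fabsf` makes the reduction an absolute maximum (amax), insensitive to sign. One step of that reduction in isolation, as a hypothetical helper:

```cpp
// Combine a lane's running amax with its neighbor's, comparing magnitude only.
__device__ float amax_step(float current, float other) {
  return fmaxf(fabsf(current), fabsf(other));
}
```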
......
@@ -52,7 +52,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   }
   for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+    float other_amax = __shfl_down(amax, delta);
+#else
     float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
+#endif
     __builtin_assume(amax >= 0);
     __builtin_assume(other_amax >= 0);
     amax = fmaxf(amax, other_amax);
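For context on the lines retained around the new `#ifdef`: `__builtin_assume` hands the optimizer a fact it cannot prove on its own. Both operands are running maxima of absolute values, hence non-negative, so `fmaxf` can lower to a plain maximum with no sign handling. A minimal sketch with a hypothetical helper name:

```cpp
__device__ float amax_combine(float amax, float other_amax) {
  // Both inputs are maxima of |x|, so they are guaranteed non-negative;
  // telling the compiler so lets it simplify the fmaxf lowering.
  __builtin_assume(amax >= 0);
  __builtin_assume(other_amax >= 0);
  return fmaxf(amax, other_amax);
}
```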
@@ -119,10 +123,18 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   }
   for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+    bool other_skip_store = __shfl_down(skip_store, delta);
+#else
     bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
+#endif
     skip_store = skip_store && other_skip_store;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  skip_store = __shfl(skip_store, 0);
+#else
   skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
+#endif
   if (skip_store) {
     return;
   }
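This hunk ports the reduce-then-broadcast idiom: AND the `skip_store` predicate down to lane 0, then broadcast lane 0's verdict with `__shfl` so every lane takes the same early return. Note the HIP shuffles take no mask argument, since AMD wavefronts execute in lockstep. A self-contained sketch with `warp_all` as a hypothetical name:

```cpp
template <int kThreadsPerWarp>  // assumed to equal the hardware warp size
__device__ bool warp_all(bool pred) {
  // AND-reduce the predicate down to lane 0 ...
#pragma unroll
  for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2)
#ifdef __HIP_PLATFORM_AMD__
    pred = pred && __shfl_down(pred, delta);
#else
    pred = pred && __shfl_down_sync(0xFFFFFFFF, pred, delta);
#endif
  // ... then broadcast lane 0's result so the whole warp branches uniformly.
#ifdef __HIP_PLATFORM_AMD__
  return __shfl(pred, 0);
#else
  return __shfl_sync(0xFFFFFFFF, pred, 0);
#endif
}
```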
......
@@ -13,6 +13,7 @@ import logging
 from packaging.version import Version as PkgVersion
 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.utils import (
     SplitAlongDim,
@@ -92,7 +93,7 @@ else:
         fa_utils.set_flash_attention_version()
     elif (
         torch.cuda.is_available()
-        and get_device_compute_capability() >= (8, 0)
+        and (IS_HIP_EXTENSION or get_device_compute_capability() >= (8, 0))
         and dpa_utils._NVTE_FLASH_ATTN
     ):
         attn_log.fa_logger.warning(
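The `IS_HIP_EXTENSION or ...` guard short-circuits the compute-capability test, which only has meaning for CUDA devices. A standalone sketch of the pattern, with `flash_attention_supported` as a hypothetical helper (not this module's API):

```python
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

def flash_attention_supported() -> bool:
    if not torch.cuda.is_available():
        return False
    if IS_HIP_EXTENSION:
        return True  # ROCm build: SM compute-capability tuples do not apply
    # CUDA build: flash-attn requires Ampere (SM 8.0) or newer
    return torch.cuda.get_device_capability() >= (8, 0)
```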
@@ -107,25 +108,26 @@ else:
             ),
             fa_utils.version,
         )
-    try:
-        fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
-    except PackageNotFoundError:
-        flash_attn_func_v3 = None
-        flash_attn_varlen_func_v3 = None
-        flash_attn_with_kvcache_v3 = None
-        # pass # only print warning if use_flash_attention_3 = True in get_attention_backend
-    else:
-        from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
-        from flash_attn_3.flash_attn_interface import (
-            flash_attn_varlen_func as flash_attn_varlen_func_v3,
-        )
-        from flash_attn_3.flash_attn_interface import (
-            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
-        )
-        from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
-        from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
-        fa_utils.set_flash_attention_3_params()
+    if not IS_HIP_EXTENSION:
+        try:
+            fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
+        except PackageNotFoundError:
+            flash_attn_func_v3 = None
+            flash_attn_varlen_func_v3 = None
+            flash_attn_with_kvcache_v3 = None
+            # pass # only print warning if use_flash_attention_3 = True in get_attention_backend
+        else:
+            from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
+            from flash_attn_3.flash_attn_interface import (
+                flash_attn_varlen_func as flash_attn_varlen_func_v3,
+            )
+            from flash_attn_3.flash_attn_interface import (
+                flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
+            )
+            from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
+            from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
+            fa_utils.set_flash_attention_3_params()

 class UnfusedDotProductAttention(torch.nn.Module):
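The flash-attn-3 probe now sits entirely behind `if not IS_HIP_EXTENSION:`, since that package targets CUDA only. The try/except/else shape is the standard optional-dependency idiom: bind the entry points to None when the distribution is missing, so later `is None` checks degrade gracefully instead of raising NameError. A reduced sketch of the idiom with standalone names (not this module's):

```python
from importlib.metadata import PackageNotFoundError, version as get_pkg_version

attn_func = None  # stays None when the optional package is absent
try:
    pkg_version = get_pkg_version("flash-attn-3")
except PackageNotFoundError:
    pass  # leave attn_func as None; callers check before using it
else:
    # Import only once the distribution is known to be installed.
    from flash_attn_3.flash_attn_interface import flash_attn_func as attn_func

if attn_func is None:
    print("flash-attn-3 not installed; falling back to another backend")
```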
......