Commit b65e50ba authored by yuguo's avatar yuguo
Browse files

[DCU] fix merge

parent f8c2af4c
...@@ -189,7 +189,7 @@ __global__ void thd_lse_kernel(float *lse, float *half_lse, int *cu_seqlens, int ...@@ -189,7 +189,7 @@ __global__ void thd_lse_kernel(float *lse, float *half_lse, int *cu_seqlens, int
**************************************************************************************************/ **************************************************************************************************/
template <typename dtype, int only_second_half, int tile_size, bool lse_packed> template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, __global__ void __launch_bounds__(512) thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
float *lse_per_step, int *cu_seqlens, int batch, float *lse_per_step, int *cu_seqlens, int batch,
int num_heads, int dim_per_head, int lse_seqlen, int num_heads, int dim_per_head, int lse_seqlen,
int lse_per_step_seqlen) { int lse_per_step_seqlen) {
......
...@@ -57,7 +57,7 @@ reduce_block_into_lanes(T *x, T val, int lanes = 1, ...@@ -57,7 +57,7 @@ reduce_block_into_lanes(T *x, T val, int lanes = 1,
// __SYNCWARP(); // __SYNCWARP();
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
#pragma unroll #pragma unroll
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i); for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i, THREADS_PER_WARP);
#else #else
#pragma unroll #pragma unroll
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i); for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
...@@ -104,7 +104,7 @@ reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1, ...@@ -104,7 +104,7 @@ reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
#pragma unroll #pragma unroll
for (int i = 16; i >= lanes; i >>= 1) for (int i = 16; i >= lanes; i >>= 1)
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i))); final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i, THREADS_PER_WARP)));
#else #else
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
#endif #endif
......
...@@ -52,7 +52,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) ...@@ -52,7 +52,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
} }
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
float other_amax = __shfl_down(amax, delta);
#else
float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta); float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
#endif
__builtin_assume(amax >= 0); __builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0); __builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax); amax = fmaxf(amax, other_amax);
...@@ -119,10 +123,18 @@ __global__ void __launch_bounds__(kThreadsPerBlock) ...@@ -119,10 +123,18 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
} }
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
bool other_skip_store = __shfl_down(skip_store, delta);
#else
bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta); bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
#endif
skip_store = skip_store && other_skip_store; skip_store = skip_store && other_skip_store;
} }
#ifdef __HIP_PLATFORM_AMD__
skip_store = __shfl(skip_store, 0);
#else
skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0); skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
#endif
if (skip_store) { if (skip_store) {
return; return;
} }
......
...@@ -13,6 +13,7 @@ import logging ...@@ -13,6 +13,7 @@ import logging
from packaging.version import Version as PkgVersion from packaging.version import Version as PkgVersion
import torch import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
SplitAlongDim, SplitAlongDim,
...@@ -92,7 +93,7 @@ else: ...@@ -92,7 +93,7 @@ else:
fa_utils.set_flash_attention_version() fa_utils.set_flash_attention_version()
elif ( elif (
torch.cuda.is_available() torch.cuda.is_available()
and get_device_compute_capability() >= (8, 0) and (IS_HIP_EXTENSION or get_device_compute_capability() >= (8, 0))
and dpa_utils._NVTE_FLASH_ATTN and dpa_utils._NVTE_FLASH_ATTN
): ):
attn_log.fa_logger.warning( attn_log.fa_logger.warning(
...@@ -107,14 +108,15 @@ else: ...@@ -107,14 +108,15 @@ else:
), ),
fa_utils.version, fa_utils.version,
) )
try: if not IS_HIP_EXTENSION:
try:
fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3")) fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
except PackageNotFoundError: except PackageNotFoundError:
flash_attn_func_v3 = None flash_attn_func_v3 = None
flash_attn_varlen_func_v3 = None flash_attn_varlen_func_v3 = None
flash_attn_with_kvcache_v3 = None flash_attn_with_kvcache_v3 = None
# pass # only print warning if use_flash_attention_3 = True in get_attention_backend # pass # only print warning if use_flash_attention_3 = True in get_attention_backend
else: else:
from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3 from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_3.flash_attn_interface import ( from flash_attn_3.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment