Commit b3062dab authored by zhuwenwen's avatar zhuwenwen
Browse files

support fa kvcache fp8, add VLLM_USE_QUERY_QUANT to not use q quant(todo)

parent 4e51cae7
......@@ -27,7 +27,7 @@ static inline __device__ float fp8_to_float(uint8_t input) {
}
// float -> fp8
static inline __device__ uint8_t float_to_fp8(float f) {
static inline __device__ uint8_t float_to_fp8_e4m3(float f) {
constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
uint32_t f_bits = c10::detail::fp32_to_bits(f);
......@@ -53,10 +53,35 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return result;
}
static inline __device__ uint8_t float_to_fp8_e5m2(float f) {
constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
constexpr uint32_t fp8_max = UINT32_C(143) << 23;
constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
uint32_t f_bits = c10::detail::fp32_to_bits(f);
uint8_t result = 0u;
const uint32_t sign = f_bits & UINT32_C(0x80000000);
f_bits ^= sign;
if (f_bits >= fp8_max) {
result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
} else {
if (f_bits < (UINT32_C(113) << 23)) {
f_bits = c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits)
+ c10::detail::fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
uint32_t mant_odd = (f_bits >> 21) & 1;
f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
f_bits += mant_odd;
result = static_cast<uint8_t>(f_bits >> 21);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
const float scale) {
const float scale, Fp8KVCacheDataType kv_type) {
return x;
}
......@@ -65,8 +90,10 @@ using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return __float2bfloat16(fp8_to_float(a) * scale);
}
......@@ -74,32 +101,32 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
__nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, kv_type);
res.y =
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, kv_type);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
scale);
scale, kv_type);
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, kv_type);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
......@@ -111,24 +138,27 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, float scale) {
const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return fp8_to_float(a) * scale;
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float2 f2r;
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale);
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale, kv_type);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return f2r;
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale, Fp8KVCacheDataType kv_type) {
Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
......@@ -138,18 +168,18 @@ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
// fp8x4 -> float4
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale, kv_type) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale, kv_type);
return {res.x.x, res.x.y, res.y.x, res.y.y};
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, kv_type);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
......@@ -161,7 +191,10 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
float res = fp8_to_float(a) * scale;
return float_to_half(res);
}
......@@ -169,54 +202,58 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t u16[2];
uint32_t u32;
} res;
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale);
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale, kv_type);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res.u32;
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, kv_type);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, kv_type);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, kv_type);
return tmp.u64x2;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = half_to_float(a) / scale;
return float_to_fp8(res_f);
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
}
// halfx2 -> fp8x2
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint8_t ui8[2];
uint16_t ui16;
......@@ -226,113 +263,122 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
half2 h2r;
} tmp_a;
tmp_a.ui32 = a;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale);
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale, kv_type);
return tmp.ui16;
}
// half2x2 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale, kv_type);
return tmp.ui32;
}
// half2x4 -> fp8x8
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
union {
uint2 ui2[2];
uint4 ui4;
} tmp;
tmp.ui4 = a;
uint2 res;
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale, kv_type);
return res;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, float scale) {
const __nv_bfloat16& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = (static_cast<float>(a)) / scale;
return float_to_fp8(res_f);
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
}
// bf16x2 -> fp8x2
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
const __nv_bfloat162& a, float scale) {
const __nv_bfloat162& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale, kv_type);
return tmp.ui16;
}
// bf16x4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale, kv_type);
return tmp.ui32;
}
// bf16x8 -> fp8x8
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale, Fp8KVCacheDataType kv_type) {
uint2 res;
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale, kv_type);
return res;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
return float_to_fp8(a / scale);
scaled_vec_conversion<uint8_t, float>(const float& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(a / scale);
} else {
return float_to_fp8_e5m2(a / scale);
}
}
// floatx2 -> fp8x2
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale);
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale, kv_type);
return tmp.ui16;
}
// floatx4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale, kv_type);
return tmp.ui32;
}
......@@ -433,8 +479,8 @@ scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale);
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3 || kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return scaled_vec_conversion<Tout, Tin>(x, scale, kv_dt);
}
else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tout)==1){
return scaled_vec_conversion_to_e5m2<Tin>(x, scale);
......
......@@ -5,6 +5,7 @@ from typing import Optional
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
import torch
logger = init_logger(__name__)
......@@ -68,6 +69,8 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
def flash_attn_supports_fp8() -> bool:
if current_platform.is_rocm():
return True
return get_flash_attn_version() == 3 and \
current_platform.get_device_capability().major == 9
......
......@@ -149,6 +149,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_ATTN_FP8: bool = False
VLLM_USE_QUERY_QUANT: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_FLASH_MLA_FP8: bool = False
VLLM_USE_OPT_OP: bool = False
......@@ -1071,6 +1072,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "1"))),
# flag to control if vllm should use q quant
"VLLM_USE_QUERY_QUANT":
lambda: (os.environ.get("VLLM_USE_QUERY_QUANT", "False").lower() in
("true", "1")),
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
......
......@@ -136,6 +136,17 @@ class FlashAttentionBackend(AttentionBackend):
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order
@staticmethod
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return torch.float8_e4m3fn
else:
raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")
elif kv_cache_dtype in ("fp8_e5m2"):
return torch.float8_e5m2
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@dataclass
class FlashAttentionMetadata:
......@@ -589,14 +600,19 @@ class FlashAttentionImpl(AttentionImpl):
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# key_cache = key_cache.view(torch.float8_e4m3fn)
# value_cache = value_cache.view(torch.float8_e4m3fn)
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype)
key_cache = key_cache.view(dtype)
value_cache = value_cache.view(dtype)
if envs.VLLM_USE_QUERY_QUANT:
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# Compute attention and update output up to `num_actual_tokens`.
use_local_attn = \
......@@ -620,9 +636,10 @@ class FlashAttentionImpl(AttentionImpl):
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
# descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
if not current_platform.is_rocm():
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
......@@ -672,6 +689,9 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache=True,
)
......@@ -729,6 +749,9 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
)
return output
......@@ -879,12 +902,12 @@ def cascade_attention(
return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
is_prefix_cache=True,
)
......@@ -932,12 +955,12 @@ def cascade_attention(
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
is_prefix_cache=True,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment