Unverified commit 01413e0c, authored by xiao-llm, committed by GitHub
Browse files

Fp8 paged attention update (#22222)


Signed-off-by: Xiao Yu <xiao.yu@amd.com>
Signed-off-by: xiao-llm <xiao.yu.dc@outlook.com>
Co-authored-by: Xiao Yu <xiao.yu@metamaterial.com>
Co-authored-by: Xiao Yu <xiao.yu@amd.com>
Co-authored-by: Bowen Bao <bowenbao@amd.com>
parent 0e219cd5
This diff is collapsed.
......@@ -19,4 +19,5 @@ void paged_attention(
const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale);
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale,
const std::string& mfma_type);
......@@ -48,7 +48,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale,"
" Tensor? fp8_out_scale) -> ()");
" Tensor? fp8_out_scale,"
" str mfma_type) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}
......
......@@ -117,13 +117,14 @@ def paged_attention_rocm(
k_scale: torch.Tensor,
v_scale: torch.Tensor,
fp8_out_scale: Optional[torch.Tensor] = None,
mfma_type: str = "fp8" if envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN else "f16",
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
query_start_loc, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale,
v_scale, fp8_out_scale)
v_scale, fp8_out_scale, mfma_type)
def mla_decode_kvcache_cpu(
......
......@@ -167,6 +167,7 @@ if TYPE_CHECKING:
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
......@@ -1219,6 +1220,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_RESPONSES_API_STORE":
lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))),
# If set to 1, use the fp8 MFMA path in ROCm paged attention
# (default 0, which selects the f16 MFMA path).
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "0"))),
# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
......@@ -1340,6 +1345,7 @@ def compute_hash() -> str:
"VLLM_ROCM_QUICK_REDUCE_QUANTIZATION",
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
"VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
]
for key in environment_variables_to_hash:
# if this goes out of sync with environment_variables,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment