Commit 43a52016 authored by zhuwenwen
Browse files

update rocm.py

parent dcec1db7
...@@ -140,7 +140,6 @@ class RocmPlatform(Platform): ...@@ -140,7 +140,6 @@ class RocmPlatform(Platform):
kv_cache_dtype, block_size, use_v1, kv_cache_dtype, block_size, use_v1,
use_mla) -> str: use_mla) -> str:
if use_mla: if use_mla:
<<<<<<< HEAD
if selected_backend == _Backend.TRITON_MLA or block_size != 64: if selected_backend == _Backend.TRITON_MLA or block_size != 64:
if use_v1: if use_v1:
logger.info_once("Using Triton MLA backend on V1 engine.") logger.info_once("Using Triton MLA backend on V1 engine.")
...@@ -174,40 +173,38 @@ class RocmPlatform(Platform): ...@@ -174,40 +173,38 @@ class RocmPlatform(Platform):
"flashmla.FlashMLABackend") "flashmla.FlashMLABackend")
else: else:
logger.info("Using Triton MLA backend (block size 64).") logger.info("Using Triton MLA backend (block size 64).")
return "vllm.attention.backends.triton_mla.TritonMLABackend" return "vllm.attention.backends.triton_mla.TritonMLABackend"
======= # from vllm.attention.backends.rocm_aiter_mla import (
from vllm.attention.backends.rocm_aiter_mla import ( # is_aiter_mla_enabled)
is_aiter_mla_enabled)
# if selected_backend is None:
if selected_backend is None: # selected_backend = (_Backend.ROCM_AITER_MLA if
selected_backend = (_Backend.ROCM_AITER_MLA if # is_aiter_mla_enabled() or block_size == 1
is_aiter_mla_enabled() or block_size == 1 # else _Backend.TRITON_MLA)
else _Backend.TRITON_MLA)
# if selected_backend == _Backend.TRITON_MLA:
if selected_backend == _Backend.TRITON_MLA: # if block_size != 1:
if block_size != 1: # logger.info("Using Triton MLA backend.")
logger.info("Using Triton MLA backend.") # return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501
return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 # else:
else: # raise ValueError(
raise ValueError( # f" The selected backend, {selected_backend.name},"
f" The selected backend, {selected_backend.name}," # f"does not support block size {block_size}.")
f"does not support block size {block_size}.") # elif selected_backend == _Backend.ROCM_AITER_MLA:
elif selected_backend == _Backend.ROCM_AITER_MLA: # if block_size == 1:
if block_size == 1: # logger.info("Using AITER MLA backend.")
logger.info("Using AITER MLA backend.") # return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501
return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 # else:
else: # raise ValueError(
raise ValueError( # f" The selected backend, {selected_backend.name},"
f" The selected backend, {selected_backend.name}," # f"does not support block size {block_size}."
f"does not support block size {block_size}." # "(currently only supports block size 1)")
"(currently only supports block size 1)") # else:
else:
raise ValueError( raise ValueError(
f" The selected backend, {selected_backend.name}," f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend.") f"is not MLA type while requested for MLA backend.")
>>>>>>> v0.8.5
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
...@@ -384,4 +381,4 @@ class RocmPlatform(Platform): ...@@ -384,4 +381,4 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def get_cu_count(cls, device_id: int = 0) -> int: def get_cu_count(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties( return torch.cuda.get_device_properties(
device_id).multi_processor_count device_id).multi_processor_count
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment