Commit a5aa55e8 authored by zhuwenwen's avatar zhuwenwen
Browse files

support Z100L and K100 inference

parent 5aa6d7c2
...@@ -38,7 +38,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") ...@@ -38,7 +38,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
# Supported AMD GPU architectures. # Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx928;gfx936") set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx906;gfx926;gfx928;gfx936")
# #
# Supported/expected torch versions for CUDA/ROCm. # Supported/expected torch versions for CUDA/ROCm.
......
...@@ -22,6 +22,7 @@ from vllm.logger import init_logger ...@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.utils import SUPPORT_TC
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
...@@ -544,7 +545,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -544,7 +545,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
f"Head size {head_size} is not supported by PagedAttention. " f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.") f"Supported head sizes are: {supported_head_sizes}.")
self.use_naive_attn = False if SUPPORT_TC:
self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
...@@ -572,15 +574,18 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -572,15 +574,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if not current_platform.has_device_capability(90): if not current_platform.has_device_capability(90):
self.use_naive_attn = True self.use_naive_attn = True
else: else:
try: if SUPPORT_TC:
from flash_attn import flash_attn_varlen_func # noqa: F401 try:
self.fa_attn_func = flash_attn_varlen_func from flash_attn import flash_attn_varlen_func # noqa: F401
if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN: self.fa_attn_func = flash_attn_varlen_func
from flash_attn import vllm_flash_attn_varlen_func if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func from flash_attn import vllm_flash_attn_varlen_func
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
logger.debug("Using CUTLASS FA in ROCmBackend")
except ModuleNotFoundError: logger.debug("Using CUTLASS FA in ROCmBackend")
except ModuleNotFoundError:
self.use_naive_attn = True
else:
self.use_naive_attn = True self.use_naive_attn = True
if self.use_naive_attn: if self.use_naive_attn:
......
...@@ -13,13 +13,12 @@ from vllm.attention.backends.mla.common import (MLACommonBackend, ...@@ -13,13 +13,12 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl, MLACommonImpl,
MLACommonMetadata) MLACommonMetadata)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend): class TritonMLABackend(MLACommonBackend):
@staticmethod @staticmethod
...@@ -67,7 +66,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -67,7 +66,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
"are not implemented for " "are not implemented for "
"TritonMLAImpl") "TritonMLAImpl")
self.attn_configs = get_attention_mla_configs_json(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16") if envs.VLLM_USE_TRITON_OPT_MLA:
self.attn_configs = get_attention_mla_configs_json(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError(
......
...@@ -8,15 +8,14 @@ import torch ...@@ -8,15 +8,14 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
import vllm.envs as envs import vllm.envs as envs
from vllm.utils import SUPPORT_TC
if HAS_TRITON: if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name use_tc = envs.VLLM_USE_OPT_OP and envs.VLLM_USE_TC_PAGED_ATTN and SUPPORT_TC
support_tc = gpuname.startswith('K100_AI') or gpuname.startswith('BW')
use_tc = envs.VLLM_USE_OPT_OP and envs.VLLM_USE_TC_PAGED_ATTN and support_tc
@dataclass @dataclass
class PagedAttentionMetadata: class PagedAttentionMetadata:
......
...@@ -44,6 +44,7 @@ from vllm.transformers_utils.utils import is_s3, maybe_model_redirect ...@@ -44,6 +44,7 @@ from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, get_open_port, is_torch_equal_or_newer, get_cpu_memory, get_open_port, is_torch_equal_or_newer,
random_uuid, resolve_obj_by_qualname) random_uuid, resolve_obj_by_qualname)
from vllm.utils import SUPPORT_TC
if TYPE_CHECKING: if TYPE_CHECKING:
from _typeshed import DataclassInstance from _typeshed import DataclassInstance
...@@ -1262,7 +1263,7 @@ class ModelConfig: ...@@ -1262,7 +1263,7 @@ class ModelConfig:
@property @property
def use_mla(self) -> bool: def use_mla(self) -> bool:
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE and SUPPORT_TC
@property @property
def supported_runner_types(self) -> set[RunnerType]: def supported_runner_types(self) -> set[RunnerType]:
......
...@@ -74,6 +74,9 @@ if TYPE_CHECKING: ...@@ -74,6 +74,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name
SUPPORT_TC = gpuname.startswith('K100_AI') or gpuname.startswith('BW')
# Exception strings for non-implemented encoder/decoder scenarios # Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/features/compatibility_matrix.md # Reminder: Please update docs/source/features/compatibility_matrix.md
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment