Commit a5aa55e8 authored by zhuwenwen
Browse files

support Z100L and K100 inference

parent 5aa6d7c2
......@@ -38,7 +38,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx928;gfx936")
set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx906;gfx926;gfx928;gfx936")
#
# Supported/expected torch versions for CUDA/ROCm.
......
......@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.utils import SUPPORT_TC
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
......@@ -544,6 +545,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if SUPPORT_TC:
self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
......@@ -572,6 +574,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if not current_platform.has_device_capability(90):
self.use_naive_attn = True
else:
if SUPPORT_TC:
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.fa_attn_func = flash_attn_varlen_func
......@@ -582,6 +585,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
logger.debug("Using CUTLASS FA in ROCmBackend")
except ModuleNotFoundError:
self.use_naive_attn = True
else:
self.use_naive_attn = True
if self.use_naive_attn:
if logits_soft_cap is not None:
......
......@@ -13,13 +13,12 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend):
@staticmethod
......@@ -67,6 +66,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
"are not implemented for "
"TritonMLAImpl")
if envs.VLLM_USE_TRITON_OPT_MLA:
self.attn_configs = get_attention_mla_configs_json(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
if is_quantized_kv_cache(self.kv_cache_dtype):
......
......@@ -8,15 +8,14 @@ import torch
from vllm import _custom_ops as ops
from vllm.triton_utils import HAS_TRITON
import vllm.envs as envs
from vllm.utils import SUPPORT_TC
if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name
support_tc = gpuname.startswith('K100_AI') or gpuname.startswith('BW')
use_tc = envs.VLLM_USE_OPT_OP and envs.VLLM_USE_TC_PAGED_ATTN and support_tc
use_tc = envs.VLLM_USE_OPT_OP and envs.VLLM_USE_TC_PAGED_ATTN and SUPPORT_TC
@dataclass
class PagedAttentionMetadata:
......
......@@ -44,6 +44,7 @@ from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
random_uuid, resolve_obj_by_qualname)
from vllm.utils import SUPPORT_TC
if TYPE_CHECKING:
from _typeshed import DataclassInstance
......@@ -1262,7 +1263,7 @@ class ModelConfig:
@property
def use_mla(self) -> bool:
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE and SUPPORT_TC
@property
def supported_runner_types(self) -> set[RunnerType]:
......
......@@ -74,6 +74,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name
SUPPORT_TC = gpuname.startswith('K100_AI') or gpuname.startswith('BW')
# Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/features/compatibility_matrix.md
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment