Unverified commit 6db31e7a, authored by Akash Kaothalkar and committed by GitHub
Browse files

[Hardware][PPC64LE] Enable V1 for ppc64le and ARM (#20554)


Signed-off-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
parent 977180c9
...@@ -36,6 +36,7 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ...@@ -36,6 +36,7 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
...@@ -1096,7 +1097,6 @@ class EngineArgs: ...@@ -1096,7 +1097,6 @@ class EngineArgs:
If VLLM_USE_V1 is specified by the user but the VllmConfig If VLLM_USE_V1 is specified by the user but the VllmConfig
is incompatible, we raise an error. is incompatible, we raise an error.
""" """
from vllm.platforms import current_platform
current_platform.pre_register_and_update() current_platform.pre_register_and_update()
device_config = DeviceConfig( device_config = DeviceConfig(
...@@ -1123,9 +1123,16 @@ class EngineArgs: ...@@ -1123,9 +1123,16 @@ class EngineArgs:
# Set default arguments for V0 or V1 Engine. # Set default arguments for V0 or V1 Engine.
if use_v1: if use_v1:
self._set_default_args_v1(usage_context, model_config) self._set_default_args_v1(usage_context, model_config)
# Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1
if current_platform.is_cpu(
) and current_platform.get_cpu_architecture() in (
CpuArchEnum.POWERPC, CpuArchEnum.ARM):
logger.info(
"Chunked prefill is not supported for ARM and POWER CPUs; "
"disabling it for V1 backend.")
self.enable_chunked_prefill = False
else: else:
self._set_default_args_v0(model_config) self._set_default_args_v0(model_config)
assert self.enable_chunked_prefill is not None assert self.enable_chunked_prefill is not None
if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]: if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
...@@ -1242,7 +1249,6 @@ class EngineArgs: ...@@ -1242,7 +1249,6 @@ class EngineArgs:
if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
raise ValueError("Multi-Step Chunked-Prefill is not supported " raise ValueError("Multi-Step Chunked-Prefill is not supported "
"for pipeline-parallel-size > 1") "for pipeline-parallel-size > 1")
from vllm.platforms import current_platform
if current_platform.is_cpu(): if current_platform.is_cpu():
logger.warning("Multi-Step (--num-scheduler-steps > 1) is " logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
"currently not supported for CPUs and has been " "currently not supported for CPUs and has been "
...@@ -1391,7 +1397,6 @@ class EngineArgs: ...@@ -1391,7 +1397,6 @@ class EngineArgs:
# Skip this check if we are running on a non-GPU platform, # Skip this check if we are running on a non-GPU platform,
# or if the device capability is not available # or if the device capability is not available
# (e.g. in a Ray actor without GPUs). # (e.g. in a Ray actor without GPUs).
from vllm.platforms import current_platform
if (current_platform.is_cuda() if (current_platform.is_cuda()
and current_platform.get_device_capability() and current_platform.get_device_capability()
and current_platform.get_device_capability().major < 8): and current_platform.get_device_capability().major < 8):
...@@ -1652,7 +1657,6 @@ class EngineArgs: ...@@ -1652,7 +1657,6 @@ class EngineArgs:
# as the platform that vLLM is running on (e.g. the case of scaling # as the platform that vLLM is running on (e.g. the case of scaling
# vLLM with Ray) and has no GPUs. In this case we use the default # vLLM with Ray) and has no GPUs. In this case we use the default
# values for non-H100/H200 GPUs. # values for non-H100/H200 GPUs.
from vllm.platforms import current_platform
try: try:
device_memory = current_platform.get_device_total_memory() device_memory = current_platform.get_device_total_memory()
device_name = current_platform.get_device_name().lower() device_name = current_platform.get_device_name().lower()
...@@ -1755,7 +1759,6 @@ class AsyncEngineArgs(EngineArgs): ...@@ -1755,7 +1759,6 @@ class AsyncEngineArgs(EngineArgs):
parser.add_argument('--disable-log-requests', parser.add_argument('--disable-log-requests',
action='store_true', action='store_true',
help='Disable logging requests.') help='Disable logging requests.')
from vllm.platforms import current_platform
current_platform.pre_register_and_update(parser) current_platform.pre_register_and_update(parser)
return parser return parser
......
...@@ -271,5 +271,6 @@ class CpuPlatform(Platform): ...@@ -271,5 +271,6 @@ class CpuPlatform(Platform):
"""Returns whether the current platform can use v1 by default for the """Returns whether the current platform can use v1 by default for the
supplied model configuration. supplied model configuration.
""" """
return cls.supports_v1( arch = cls.get_cpu_architecture()
model_config) and cls.get_cpu_architecture() == CpuArchEnum.X86 return (cls.supports_v1(model_config) and arch
in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM))
...@@ -316,7 +316,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): ...@@ -316,7 +316,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
block_table: BlockTable) -> None: block_table: BlockTable) -> None:
self.runner = runner self.runner = runner
self.block_table = block_table self.block_table = block_table
# For reorder # For reorder
self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs, self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
dtype=np.int64) dtype=np.int64)
...@@ -401,11 +400,14 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): ...@@ -401,11 +400,14 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
# to ensure inference when chunked_prefill is disabled
seq_lens=runner.seq_lens_cpu[:num_reqs].tolist(),
seq_lens_tensor=runner. seq_lens_tensor=runner.
seq_lens_cpu[num_prompt_req:num_reqs], # decode seq_lens_cpu[num_prompt_req:num_reqs], # decode
max_decode_seq_len=max_decode_seq_len, # decode max_decode_seq_len=max_decode_seq_len, # decode
block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode
chunked_prefill=True, chunked_prefill=self.runner.scheduler_config.
chunked_prefill_enabled,
max_query_len=max_query_len, max_query_len=max_query_len,
max_kv_len=max_prefill_seq_len, max_kv_len=max_prefill_seq_len,
prefill_query_start_loc=runner. prefill_query_start_loc=runner.
......
...@@ -11,7 +11,7 @@ from vllm.config import VllmConfig ...@@ -11,7 +11,7 @@ from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.platforms import current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
...@@ -43,8 +43,12 @@ class CPUWorker(Worker): ...@@ -43,8 +43,12 @@ class CPUWorker(Worker):
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
self.local_omp_cpuid = "all" self.local_omp_cpuid = "all"
if omp_cpuids == "auto": if omp_cpuids == "auto":
self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes( if current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC:
) self.local_omp_cpuid = (
self.get_cpus_id_binding_based_on_numa_nodes_ppc64le())
else:
self.local_omp_cpuid = (
self.get_cpus_id_binding_based_on_numa_nodes())
else: else:
self.local_omp_cpuid = omp_cpuids.split("|")[self.rank] self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]
...@@ -153,3 +157,57 @@ class CPUWorker(Worker): ...@@ -153,3 +157,57 @@ class CPUWorker(Worker):
"fallback to no thread-binding. To get better performance," "fallback to no thread-binding. To get better performance,"
"please try to manually bind threads.") "please try to manually bind threads.")
return rank_to_cpus return rank_to_cpus
def get_cpus_id_binding_based_on_numa_nodes_ppc64le(self) -> str:
    """
    Power (ppc64le) specific: select a subset of threads per core for
    the NUMA node assigned to this rank.

    The CPUs of the NUMA node indexed by ``self.rank`` (intersected with
    the process CPU affinity) are filtered to the ids with
    ``cpu % 8 < 4``; the trailing ``VLLM_CPU_NUM_OF_RESERVED_CPU`` ids
    (capped at half of the filtered list) are then dropped and the rest
    is joined into a comma-separated string.

    This is intended to be robust to the SMT mode (SMT-8, SMT-4, etc.)
    because only CPUs present in the affinity/NUMA lists are considered,
    and to avoid oversubscription of logical CPUs on Power.
    NOTE(review): the ``cpu % 8 < 4`` filter presumably assumes logical
    CPU ids are numbered with a stride of 8 per core -- confirm.

    Returns:
        The comma-separated CPU id list for this rank, or the current
        value of ``self.local_omp_cpuid`` unchanged when the ``numa`` /
        ``psutil`` packages are unavailable or when the world size is
        larger than the number of usable NUMA nodes.
    """

    def select_threads_per_power_core(node_cpu_ids):
        # Keep only the ids whose position within a group of 8 is 0..3.
        return [cpu for cpu in node_cpu_ids if cpu % 8 < 4]

    # Default result: leave the current binding setting untouched.
    rank_to_cpus = self.local_omp_cpuid
    world_size = self.vllm_config.parallel_config.world_size

    libnuma_found = util.find_spec("numa") is not None
    psutil_found = util.find_spec("psutil") is not None
    if not (libnuma_found and psutil_found):
        logger.warning(
            "Auto thread-binding is not supported due to "
            "the lack of package numa and psutil, "
            "fallback to no thread-binding. To get better performance, "
            "please try to manually bind threads.")
        return rank_to_cpus

    # Imported lazily: both packages are optional dependencies.
    import psutil
    from numa import info

    cpus_allow_list = psutil.Process().cpu_affinity()
    numa_size = info.get_num_configured_nodes()

    # Per NUMA node, the sorted CPUs this process is allowed to run on;
    # nodes with no allowed CPU are skipped entirely.
    node_to_cpus = []
    for i in range(numa_size):
        node_intersect = set(
            info.node_to_cpus(i)).intersection(cpus_allow_list)
        if node_intersect:
            node_to_cpus.append(sorted(node_intersect))

    if world_size > len(node_to_cpus):
        logger.error(
            "Auto thread-binding failed due to "
            "world size: %d is larger than "
            "allowed NUMA nodes number: %d. "
            "Please try to bind threads manually.", world_size,
            len(node_to_cpus))
        return rank_to_cpus

    node_cpus_this_rank = node_to_cpus[self.rank]
    selected_cpus = select_threads_per_power_core(node_cpus_this_rank)
    if selected_cpus:
        node_cpus_this_rank = selected_cpus
    else:
        # The affinity mask may contain only ids rejected by the filter;
        # binding to an empty list would be meaningless, so keep the
        # node's allowed CPUs as they are.
        logger.warning(
            "ppc64le thread selection left no CPUs on the NUMA node of "
            "rank %d; using all allowed CPUs of that node instead.",
            self.rank)

    cpu_count_per_numa = len(node_cpus_this_rank)
    # Never reserve more than half of the CPUs available to this rank.
    num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                              cpu_count_per_numa // 2)
    end = cpu_count_per_numa - num_of_reserved_cpu
    rank_to_cpus = ','.join(str(x) for x in node_cpus_this_rank[:end])
    logger.info("ppc64le thread-binding list: %s", rank_to_cpus)
    return rank_to_cpus
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment