Commit 6d172ab4 authored by 王敏's avatar 王敏
Browse files

1.临时解决gemm调用到blaslt问题

2.支持设置v1的chunked_prefill开关
parent fdda4d82
......@@ -423,6 +423,10 @@ class ModelConfig:
- "vllm" will use the vLLM model implementation.\n
- "transformers" will use the Transformers model implementation."""
enable_chunked_prefill: Optional[bool] = None
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
......@@ -452,6 +456,7 @@ class ModelConfig:
factors.append(self.rope_theta)
# hf_config can control how the model looks!
factors.append(self.hf_config.to_json_string())
factors.append(self.enable_chunked_prefill)
str_factors = str(factors)
assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest()
......
......@@ -956,6 +956,7 @@ class EngineArgs:
override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl,
enable_chunked_prefill=self.enable_chunked_prefill
)
def create_load_config(self) -> LoadConfig:
......@@ -1046,7 +1047,7 @@ class EngineArgs:
# Set default arguments for V0 or V1 Engine.
if use_v1:
self._set_default_args_v1(usage_context)
self._set_default_args_v1(usage_context, model_config)
else:
self._set_default_args_v0(model_config)
......@@ -1532,12 +1533,16 @@ class EngineArgs:
if self.max_num_seqs is None:
self.max_num_seqs = 256
def _set_default_args_v1(self, usage_context: UsageContext) -> None:
def _set_default_args_v1(self, usage_context: UsageContext, model_config: ModelConfig) -> None:
"""Set Default Arguments for V1 Engine."""
# V1 always uses chunked prefills.
self.enable_chunked_prefill = True
if model_config.enable_chunked_prefill is not None and \
model_config.enable_chunked_prefill is False:
self.enable_chunked_prefill = False
# V1 enables prefix caching by default.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
......
......@@ -100,6 +100,15 @@ def with_amdsmi_context(fn):
return wrapper
def device_id_to_physical_device_id(device_id: int) -> int:
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
return device_id
@cache
def on_gfx1x() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
......@@ -396,13 +405,10 @@ class RocmPlatform(Platform):
@with_amdsmi_context
@lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = cls.device_id_to_physical_device_id(device_id)
physical_device_id = device_id_to_physical_device_id(device_id)
handle = amdsmi_get_processor_handles()[physical_device_id]
asic_info = amdsmi_get_gpu_asic_info(handle)
device_name: str = asic_info["device_id"]
if device_name in _ROCM_DEVICE_ID_NAME_MAP:
return _ROCM_DEVICE_ID_NAME_MAP[device_name]
return asic_info["market_name"]
# return amdsmi_get_gpu_asic_info(handle)["market_name"]
return torch.cuda.get_device_name(device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
......
......@@ -28,6 +28,7 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
torch._C._set_blas_preferred_backend(torch._C._BlasBackend.Cublas)
logger = init_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment