Unverified commit 4b1d141f, authored by Yang Zheng and committed by GitHub
Browse files

[PP] Correct cache size check (#13873)


Signed-off-by: Yang Zheng <zhengy.gator@gmail.com>
parent 10c3b8c1
...@@ -258,9 +258,10 @@ class HPUWorker(LocalOrDistributedWorkerBase): ...@@ -258,9 +258,10 @@ class HPUWorker(LocalOrDistributedWorkerBase):
This also warms up the model, which may record CUDA graphs. This also warms up the model, which may record CUDA graphs.
""" """
raise_if_cache_size_invalid(num_gpu_blocks, raise_if_cache_size_invalid(
self.cache_config.block_size, num_gpu_blocks, self.cache_config.block_size,
self.model_config.max_model_len) self.model_config.max_model_len,
self.parallel_config.pipeline_parallel_size)
self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks
...@@ -442,13 +443,13 @@ def init_worker_distributed_environment( ...@@ -442,13 +443,13 @@ def init_worker_distributed_environment(
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,
max_model_len) -> None: pipeline_parallel_size) -> None:
if num_gpu_blocks <= 0: if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. " raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when " "Try increasing `gpu_memory_utilization` when "
"initializing the engine.") "initializing the engine.")
max_seq_len = block_size * num_gpu_blocks max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size)
if max_model_len > max_seq_len: if max_model_len > max_seq_len:
raise ValueError( raise ValueError(
f"The model's max seq len ({max_model_len}) " f"The model's max seq len ({max_model_len}) "
......
...@@ -288,10 +288,11 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -288,10 +288,11 @@ class Worker(LocalOrDistributedWorkerBase):
This also warms up the model, which may record CUDA graphs. This also warms up the model, which may record CUDA graphs.
""" """
raise_if_cache_size_invalid(num_gpu_blocks, raise_if_cache_size_invalid(
self.cache_config.block_size, num_gpu_blocks, self.cache_config.block_size,
self.cache_config.is_attention_free, self.cache_config.is_attention_free,
self.model_config.max_model_len) self.model_config.max_model_len,
self.parallel_config.pipeline_parallel_size)
self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks
...@@ -530,7 +531,7 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): ...@@ -530,7 +531,7 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
max_model_len) -> None: max_model_len, pipeline_parallel_size) -> None:
if is_attention_free and num_gpu_blocks != 0: if is_attention_free and num_gpu_blocks != 0:
raise ValueError("No memory should be allocated for the cache blocks " raise ValueError("No memory should be allocated for the cache blocks "
f"for an attention-free model, but {num_gpu_blocks} " f"for an attention-free model, but {num_gpu_blocks} "
...@@ -539,7 +540,7 @@ def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, ...@@ -539,7 +540,7 @@ def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
raise ValueError("No available memory for the cache blocks. " raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when " "Try increasing `gpu_memory_utilization` when "
"initializing the engine.") "initializing the engine.")
max_seq_len = block_size * num_gpu_blocks max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size)
if not is_attention_free and max_model_len > max_seq_len: if not is_attention_free and max_model_len > max_seq_len:
raise ValueError( raise ValueError(
f"The model's max seq len ({max_model_len}) " f"The model's max seq len ({max_model_len}) "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment