Unverified Commit 464dd985 authored by Woosuk Kwon, committed by GitHub

Fix num_gpus when TP > 1 (#1852)

parent c07a4428
@@ -301,7 +301,16 @@ class AsyncLLMEngine:
         elif self.worker_use_ray:
             engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
         else:
-            engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
+            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
+            # order of the arguments.
+            cache_config = args[1]
+            parallel_config = args[2]
+            if parallel_config.tensor_parallel_size == 1:
+                num_gpus = cache_config.gpu_memory_utilization
+            else:
+                num_gpus = 1
+            engine_class = ray.remote(num_gpus=num_gpus)(
+                self._engine_class).remote
         return engine_class(*args, **kwargs)

     async def engine_step(self) -> bool:
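For context, the hunk above relies on Ray accepting a fractional num_gpus when turning a class into a remote actor, so a value like gpu_memory_utilization (0.9 by default) reserves only part of a device while TP > 1 keeps a whole GPU. Below is a minimal, hypothetical sketch of that pattern outside vLLM; EngineActor and the literal values are illustrative stand-ins, not the actual engine class.

import ray


class EngineActor:
    """Stand-in for the engine actor (illustrative; not vLLM code)."""

    def ping(self) -> str:
        return "alive"


ray.init()  # assumes at least one GPU is visible to Ray

# Mirror the selection in the hunk above: a fractional GPU when TP == 1,
# a whole GPU otherwise. 0.9 matches vLLM's default gpu_memory_utilization.
tensor_parallel_size = 1
gpu_memory_utilization = 0.9
num_gpus = gpu_memory_utilization if tensor_parallel_size == 1 else 1

# ray.remote(...) with options returns a decorator; wrapping the class and
# calling .remote() starts the actor with the fractional GPU reservation.
engine = ray.remote(num_gpus=num_gpus)(EngineActor).remote()
print(ray.get(engine.ping.remote()))  # -> "alive"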
@@ -159,9 +159,13 @@ class LLMEngine:
         for bundle in placement_group.bundle_specs:
             if not bundle.get("GPU", 0):
                 continue
+            if self.parallel_config.tensor_parallel_size == 1:
+                num_gpus = self.cache_config.gpu_memory_utilization
+            else:
+                num_gpus = 1
             worker = ray.remote(
                 num_cpus=0,
-                num_gpus=self.cache_config.gpu_memory_utilization,
+                num_gpus=num_gpus,
                 scheduling_strategy=PlacementGroupSchedulingStrategy(
                     placement_group=placement_group,
                     placement_group_capture_child_tasks=True),
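Reading the second hunk the same way: a minimal, hypothetical sketch of launching one actor per GPU bundle of a placement group with PlacementGroupSchedulingStrategy, applying the same num_gpus selection as the fix. The Worker class, GPU counts, and bundle layout are assumptions for illustration, not vLLM's actual worker setup.

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


class Worker:
    """Stand-in for a vLLM worker (illustrative; not vLLM code)."""

    def ping(self) -> str:
        return "alive"


ray.init()  # assumes tensor_parallel_size GPUs are visible to Ray

tensor_parallel_size = 2
gpu_memory_utilization = 0.9

# One single-GPU bundle per tensor-parallel rank (assumed layout for this sketch).
pg = placement_group([{"GPU": 1}] * tensor_parallel_size, strategy="PACK")
ray.get(pg.ready())

# Same selection as the fix above: a fractional GPU only in the TP == 1 case.
num_gpus = gpu_memory_utilization if tensor_parallel_size == 1 else 1

workers = []
for bundle in pg.bundle_specs:
    if not bundle.get("GPU", 0):
        continue
    worker = ray.remote(
        num_cpus=0,
        num_gpus=num_gpus,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True),
    )(Worker).remote()
    workers.append(worker)

print(ray.get([w.ping.remote() for w in workers]))  # -> ["alive", "alive"]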