Unverified Commit dd650e0e authored by Zilin Zhu's avatar Zilin Zhu Committed by GitHub
Browse files

[RL] fix skip_server_warmup and rl health_generate logic (#8757)

parent a9471542
......@@ -1172,6 +1172,8 @@ def _wait_and_warmup(
pipe_finish_writer,
):
return
else:
_global_state.tokenizer_manager.server_status = ServerStatus.Up
logger.info("The server is fired up and ready to roll!")
......
......@@ -473,6 +473,7 @@ class Scheduler(
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
self.offload_tags = set()
self.init_profier()
self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
......@@ -1040,7 +1041,9 @@ class Scheduler(
for recv_req in recv_reqs:
# If it is a health check generation request and there are running requests, ignore it.
if is_health_check_generate_req(recv_req) and (
self.chunked_req is not None or not self.running_batch.is_empty()
self.chunked_req is not None
or not self.running_batch.is_empty()
or len(self.offload_tags) > 0
):
self.return_health_check_ct += 1
continue
......
......@@ -78,6 +78,9 @@ class SchedulerUpdateWeightsMixin:
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
for tag in tags:
self.offload_tags.add(tag)
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
self.flush_cache()
......@@ -97,6 +100,9 @@ class SchedulerUpdateWeightsMixin:
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
for tag in tags:
self.offload_tags.remove(tag)
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
torch.distributed.barrier(self.tp_cpu_group)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment