Unverified commit fbf91daf, authored by Schwinn Saereesitthipitak and committed by GitHub
Browse files

fix(sglang): use correct API and deps for memory occupation endpoints (#5635)

parent f976f02b
...@@ -70,6 +70,11 @@ class BaseWorkerHandler(ABC): ...@@ -70,6 +70,11 @@ class BaseWorkerHandler(ABC):
2. Pause generation - drain in-flight requests 2. Pause generation - drain in-flight requests
3. Release memory - safe now that no requests are active 3. Release memory - safe now that no requests are active
""" """
from sglang.srt.managers.io_struct import (
PauseGenerationReqInput,
ReleaseMemoryOccupationReqInput,
)
tags = body.get("tags", body.get("tag", None)) tags = body.get("tags", body.get("tag", None))
if tags is None: if tags is None:
tags = ["kv_cache", "weights", "cuda_graph"] tags = ["kv_cache", "weights", "cuda_graph"]
...@@ -84,10 +89,14 @@ class BaseWorkerHandler(ABC): ...@@ -84,10 +89,14 @@ class BaseWorkerHandler(ABC):
) )
# Step 2: Pause generation to drain in-flight requests # Step 2: Pause generation to drain in-flight requests
await self.engine.async_pause_generation() pause_req = PauseGenerationReqInput()
await self.engine.tokenizer_manager.pause_generation(pause_req)
# Step 3: Release memory now that it's safe # Step 3: Release memory now that it's safe
await self.engine.async_release_memory_occupation(tags) release_req = ReleaseMemoryOccupationReqInput(tags=tags)
await self.engine.tokenizer_manager.release_memory_occupation(
release_req, None
)
return { return {
"status": "ok", "status": "ok",
...@@ -109,16 +118,25 @@ class BaseWorkerHandler(ABC): ...@@ -109,16 +118,25 @@ class BaseWorkerHandler(ABC):
2. Continue generation - ready to serve requests 2. Continue generation - ready to serve requests
3. Re-register to discovery - allow frontend to route here 3. Re-register to discovery - allow frontend to route here
""" """
from sglang.srt.managers.io_struct import (
ContinueGenerationReqInput,
ResumeMemoryOccupationReqInput,
)
tags = body.get("tags", body.get("tag", None)) tags = body.get("tags", body.get("tag", None))
if tags is None: if tags is None:
tags = ["kv_cache", "weights", "cuda_graph"] tags = ["kv_cache", "weights", "cuda_graph"]
try: try:
# Step 1: Resume memory first - must be ready before accepting requests # Step 1: Resume memory first - must be ready before accepting requests
await self.engine.async_resume_memory_occupation(tags) resume_req = ResumeMemoryOccupationReqInput(tags=tags)
await self.engine.tokenizer_manager.resume_memory_occupation(
resume_req, None
)
# Step 2: Continue generation # Step 2: Continue generation
await self.engine.async_continue_generation() continue_req = ContinueGenerationReqInput()
await self.engine.tokenizer_manager.continue_generation(continue_req)
# Step 3: Re-register to discovery so frontend can route to us # Step 3: Re-register to discovery so frontend can route to us
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment