Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
c09a9aad
Unverified
Commit
c09a9aad
authored
Mar 06, 2026
by
Schwinn Saereesitthipitak
Committed by
GitHub
Mar 06, 2026
Browse files
fix: guard SGLang/vLLM memory occupation control endpoints (#6967)
parent
9b2b44e3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
400 additions
and
96 deletions
+400
-96
components/src/dynamo/sglang/request_handlers/handler_base.py
...onents/src/dynamo/sglang/request_handlers/handler_base.py
+75
-56
components/src/dynamo/sglang/tests/test_sglang_memory_occupation_handlers.py
...mo/sglang/tests/test_sglang_memory_occupation_handlers.py
+170
-0
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+52
-40
components/src/dynamo/vllm/tests/test_vllm_sleep_wake_handlers.py
...ts/src/dynamo/vllm/tests/test_vllm_sleep_wake_handlers.py
+103
-0
No files found.
components/src/dynamo/sglang/request_handlers/handler_base.py
View file @
c09a9aad
...
@@ -19,6 +19,10 @@ from dynamo.runtime import DistributedRuntime
...
@@ -19,6 +19,10 @@ from dynamo.runtime import DistributedRuntime
from
dynamo.sglang.args
import
Config
from
dynamo.sglang.args
import
Config
from
dynamo.sglang.publisher
import
DynamoSglangPublisher
from
dynamo.sglang.publisher
import
DynamoSglangPublisher
# Keep default tags minimal and safe for general use.
# "cuda_graph" can still be requested explicitly, but it requires LD_PRELOAD setup.
DEFAULT_MEMORY_OCCUPATION_TAGS
=
[
"kv_cache"
,
"weights"
]
class
BaseGenerativeHandler
(
ABC
):
class
BaseGenerativeHandler
(
ABC
):
"""Minimal base class for all generative handlers (LLM, diffusion, etc.).
"""Minimal base class for all generative handlers (LLM, diffusion, etc.).
...
@@ -144,6 +148,8 @@ class BaseWorkerHandler(BaseGenerativeHandler):
...
@@ -144,6 +148,8 @@ class BaseWorkerHandler(BaseGenerativeHandler):
# have an sgl.Engine.
# have an sgl.Engine.
self
.
input_param_manager
=
InputParamManager
(
None
)
self
.
input_param_manager
=
InputParamManager
(
None
)
self
.
_engine_supports_priority
=
False
self
.
_engine_supports_priority
=
False
self
.
_memory_occupation_lock
=
asyncio
.
Lock
()
self
.
_memory_released
=
False
def
_priority_kwargs
(
self
,
priority
:
Any
)
->
Dict
[
str
,
Any
]:
def
_priority_kwargs
(
self
,
priority
:
Any
)
->
Dict
[
str
,
Any
]:
if
priority
is
not
None
and
self
.
_engine_supports_priority
:
if
priority
is
not
None
and
self
.
_engine_supports_priority
:
...
@@ -154,8 +160,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
...
@@ -154,8 +160,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
"""Release GPU memory occupation and unregister from discovery.
"""Release GPU memory occupation and unregister from discovery.
Args:
Args:
body: Dict with optional 'tags' key for which memory to release.
body: Unused. Release always targets default tags.
Default: ["kv_cache", "weights", "cuda_graph"]
Order of operations:
Order of operations:
1. Unregister from discovery - stop accepting new requests
1. Unregister from discovery - stop accepting new requests
...
@@ -167,43 +172,50 @@ class BaseWorkerHandler(BaseGenerativeHandler):
...
@@ -167,43 +172,50 @@ class BaseWorkerHandler(BaseGenerativeHandler):
ReleaseMemoryOccupationReqInput
,
ReleaseMemoryOccupationReqInput
,
)
)
tags
=
body
.
get
(
"tags"
,
body
.
get
(
"tag"
,
None
))
tags
=
list
(
DEFAULT_MEMORY_OCCUPATION_TAGS
)
if
tags
is
None
:
tokenizer_manager
=
(
tags
=
[
"kv_cache"
,
"weights"
,
"cuda_graph"
]
getattr
(
self
.
engine
,
"tokenizer_manager"
,
None
)
if
self
.
engine
is
not
None
else
None
)
if
tokenizer_manager
is
None
:
return
{
"status"
:
"error"
,
"message"
:
"memory control not supported on this worker"
,
}
async
with
self
.
_memory_occupation_lock
:
if
self
.
_memory_released
:
return
{
"status"
:
"ok"
,
"message"
:
"Memory already released"
,
}
try
:
# Step 1: Unregister endpoint from discovery FIRST
try
:
try
:
await
self
.
generate_endpoint
.
unregister_endpoint_instance
()
# Stop new requests and drain in-flight work before releasing memory.
except
Exception
as
unreg_err
:
if
self
.
generate_endpoint
is
not
None
:
logging
.
warning
(
await
self
.
generate_endpoint
.
unregister_endpoint_instance
()
f
"Failed to unregister endpoint from discovery:
{
unreg_err
}
"
)
# Step 2: Pause generation to drain in-flight requests
pause_req
=
PauseGenerationReqInput
()
pause_req
=
PauseGenerationReqInput
()
await
tokenizer_manager
.
pause_generation
(
pause_req
)
await
self
.
engine
.
tokenizer_manager
.
pause_generation
(
pause_req
)
# Step 3: Release memory now that it's safe
release_req
=
ReleaseMemoryOccupationReqInput
(
tags
=
tags
)
release_req
=
ReleaseMemoryOccupationReqInput
(
tags
=
tags
)
await
tokenizer_manager
.
release_memory_occupation
(
release_req
,
None
)
await
self
.
engine
.
tokenizer_manager
.
release_memory_occupation
(
self
.
_memory_released
=
True
release_req
,
None
)
return
{
return
{
"status"
:
"ok"
,
"status"
:
"ok"
,
"message"
:
f
"Memory released for tags:
{
tags
}
"
,
"message"
:
f
"Memory released for tags:
{
tags
}
"
,
}
}
except
Exception
as
e
:
except
Exception
as
e
:
logging
.
error
(
f
"Failed to release memory occupation:
{
e
}
"
)
logging
.
error
(
f
"Failed to release memory occupation:
{
e
}
"
)
return
{
"status"
:
"error"
,
"message"
:
str
(
e
)}
return
{
"status"
:
"error"
,
"message"
:
str
(
e
)}
async
def
resume_memory_occupation
(
self
,
body
:
dict
)
->
dict
:
async
def
resume_memory_occupation
(
self
,
body
:
dict
)
->
dict
:
"""Resume GPU memory occupation and re-register to discovery.
"""Resume GPU memory occupation and re-register to discovery.
Args:
Args:
body: Dict with optional 'tags' key for which memory to resume.
body: Unused. Resume always targets default tags.
Default: ["kv_cache", "weights", "cuda_graph"]
Order of operations:
Order of operations:
1. Resume memory - restore GPU allocations
1. Resume memory - restore GPU allocations
...
@@ -215,36 +227,43 @@ class BaseWorkerHandler(BaseGenerativeHandler):
...
@@ -215,36 +227,43 @@ class BaseWorkerHandler(BaseGenerativeHandler):
ResumeMemoryOccupationReqInput
,
ResumeMemoryOccupationReqInput
,
)
)
tags
=
body
.
get
(
"tags"
,
body
.
get
(
"tag"
,
None
))
tags
=
list
(
DEFAULT_MEMORY_OCCUPATION_TAGS
)
if
tags
is
None
:
tokenizer_manager
=
(
tags
=
[
"kv_cache"
,
"weights"
,
"cuda_graph"
]
getattr
(
self
.
engine
,
"tokenizer_manager"
,
None
)
if
self
.
engine
is
not
None
try
:
else
None
# Step 1: Resume memory first - must be ready before accepting requests
)
resume_req
=
ResumeMemoryOccupationReqInput
(
tags
=
tags
)
if
tokenizer_manager
is
None
:
await
self
.
engine
.
tokenizer_manager
.
resume_memory_occupation
(
return
{
resume_req
,
None
"status"
:
"error"
,
)
"message"
:
"memory control not supported on this worker"
,
}
# Step 2: Continue generation
async
with
self
.
_memory_occupation_lock
:
continue_req
=
ContinueGenerationReqInput
()
if
not
self
.
_memory_released
:
await
self
.
engine
.
tokenizer_manager
.
continue_generation
(
continue_req
)
return
{
"status"
:
"ok"
,
"message"
:
"Memory already resumed"
,
}
# Step 3: Re-register to discovery so frontend can route to us
try
:
try
:
await
self
.
generate_endpoint
.
register_endpoint_instance
()
resume_req
=
ResumeMemoryOccupationReqInput
(
tags
=
tags
)
except
Exception
as
reg_err
:
await
tokenizer_manager
.
resume_memory_occupation
(
resume_req
,
None
)
logging
.
warning
(
continue_req
=
ContinueGenerationReqInput
()
f
"Failed to re-register endpoint to discovery:
{
reg_err
}
"
await
tokenizer_manager
.
continue_generation
(
continue_req
)
)
if
self
.
generate_endpoint
is
not
None
:
return
{
await
self
.
generate_endpoint
.
register_endpoint_instance
()
"status"
:
"ok"
,
"message"
:
f
"Memory resumed for tags:
{
tags
}
"
,
self
.
_memory_released
=
False
}
except
Exception
as
e
:
return
{
logging
.
error
(
f
"Failed to resume memory occupation:
{
e
}
"
)
"status"
:
"ok"
,
return
{
"status"
:
"error"
,
"message"
:
str
(
e
)}
"message"
:
f
"Memory resumed for tags:
{
tags
}
"
,
}
except
Exception
as
e
:
logging
.
error
(
f
"Failed to resume memory occupation:
{
e
}
"
)
return
{
"status"
:
"error"
,
"message"
:
str
(
e
)}
async
def
start_profile
(
self
,
body
:
dict
)
->
dict
:
async
def
start_profile
(
self
,
body
:
dict
)
->
dict
:
"""Start profiling on the engine.
"""Start profiling on the engine.
...
...
components/src/dynamo/sglang/tests/test_sglang_memory_occupation_handlers.py
0 → 100644
View file @
c09a9aad
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
sys
import
types
from
types
import
SimpleNamespace
from
unittest.mock
import
AsyncMock
import
pytest
from
dynamo.sglang.request_handlers.handler_base
import
(
DEFAULT_MEMORY_OCCUPATION_TAGS
,
BaseWorkerHandler
,
)
# Module-level pytest markers: pytest applies every marker in this list to all
# tests collected from this file.
pytestmark = [
    pytest.mark.unit,
    pytest.mark.sglang,
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
]
@pytest.fixture(autouse=True)
def _stub_sglang_io_struct(monkeypatch):
    """Register a fake ``sglang.srt.managers.io_struct`` module for every test.

    The handler methods under test import their request classes from that
    module; substituting a lightweight stand-in in ``sys.modules`` keeps these
    unit tests independent from CUDA-only sglang imports. ``monkeypatch``
    restores the original ``sys.modules`` entry after each test.
    """

    class _Req:
        # Minimal request object: it only remembers the ``tags`` it was built with.
        def __init__(self, tags=None):
            self.tags = tags

    module_name = "sglang.srt.managers.io_struct"
    fake_module = types.ModuleType(module_name)

    # All four request types share the same stand-in class.
    for request_cls_name in (
        "PauseGenerationReqInput",
        "ReleaseMemoryOccupationReqInput",
        "ResumeMemoryOccupationReqInput",
        "ContinueGenerationReqInput",
    ):
        setattr(fake_module, request_cls_name, _Req)

    monkeypatch.setitem(sys.modules, module_name, fake_module)
# Minimal concrete subclass of BaseWorkerHandler used by the tests below.
# NOTE(review): presumably needed because ``generate`` is abstract on the base
# class - confirm against handler_base.py.
class _TestWorkerHandler(BaseWorkerHandler):
    # Trivial async generator: yields a single empty dict and ignores its inputs.
    async def generate(self, request, context):
        yield {}
def _make_handler() -> _TestWorkerHandler:
    """Return a handler whose engine and endpoint are ``AsyncMock`` stubs.

    The instance is created with ``__new__`` so ``__init__`` never runs; only
    the attributes the memory-occupation methods read are populated here.
    """
    fake_tokenizer_manager = SimpleNamespace(
        pause_generation=AsyncMock(),
        release_memory_occupation=AsyncMock(),
        resume_memory_occupation=AsyncMock(),
        continue_generation=AsyncMock(),
    )
    fake_endpoint = SimpleNamespace(
        unregister_endpoint_instance=AsyncMock(),
        register_endpoint_instance=AsyncMock(),
    )

    stub = _TestWorkerHandler.__new__(_TestWorkerHandler)
    stub.engine = SimpleNamespace(tokenizer_manager=fake_tokenizer_manager)
    stub.generate_endpoint = fake_endpoint
    # State normally set up by the real constructor.
    stub._memory_occupation_lock = asyncio.Lock()
    stub._memory_released = False
    return stub
@pytest.mark.asyncio
async def test_resume_before_release_is_noop():
    """Resuming a handler that never released memory must not touch the engine."""
    worker = _make_handler()
    manager = worker.engine.tokenizer_manager

    outcome = await worker.resume_memory_occupation({})

    assert outcome["status"] == "ok"
    assert outcome["message"] == "Memory already resumed"
    # None of the resume-side collaborators may have been awaited.
    for untouched in (
        manager.resume_memory_occupation,
        manager.continue_generation,
        worker.generate_endpoint.register_endpoint_instance,
    ):
        untouched.assert_not_awaited()
@pytest.mark.asyncio
async def test_release_and_resume_are_idempotent():
    """Repeating release or resume succeeds but reaches the engine only once."""
    worker = _make_handler()
    manager = worker.engine.tokenizer_manager
    endpoint = worker.generate_endpoint

    # Call order matters: release twice, then resume twice.
    release_first = await worker.release_memory_occupation({})
    release_repeat = await worker.release_memory_occupation({})
    resume_first = await worker.resume_memory_occupation({})
    resume_repeat = await worker.resume_memory_occupation({})

    for outcome in (release_first, release_repeat, resume_first, resume_repeat):
        assert outcome["status"] == "ok"
    assert release_repeat["message"] == "Memory already released"
    assert resume_repeat["message"] == "Memory already resumed"

    assert DEFAULT_MEMORY_OCCUPATION_TAGS == ["kv_cache", "weights"]
    # First positional argument of the single real call to each engine method.
    sent_release = manager.release_memory_occupation.await_args.args[0]
    sent_resume = manager.resume_memory_occupation.await_args.args[0]
    assert sent_release.tags == DEFAULT_MEMORY_OCCUPATION_TAGS
    assert sent_resume.tags == DEFAULT_MEMORY_OCCUPATION_TAGS

    # Every collaborator was awaited exactly once despite the repeated calls.
    for awaited_once in (
        manager.pause_generation,
        manager.release_memory_occupation,
        endpoint.unregister_endpoint_instance,
        manager.resume_memory_occupation,
        manager.continue_generation,
        endpoint.register_endpoint_instance,
    ):
        awaited_once.assert_awaited_once()
@pytest.mark.asyncio
async def test_resume_uses_default_tags_even_when_request_specifies_subset():
    """A ``tags`` subset in the request body is ignored; resume uses the defaults."""
    worker = _make_handler()
    manager = worker.engine.tokenizer_manager

    await worker.release_memory_occupation({"tags": ["weights"]})
    outcome = await worker.resume_memory_occupation({"tags": ["weights"]})

    assert outcome["status"] == "ok"
    # Inspect the request object actually handed to the engine.
    sent_resume = manager.resume_memory_occupation.await_args.args[0]
    assert sent_resume.tags == DEFAULT_MEMORY_OCCUPATION_TAGS
    manager.continue_generation.assert_awaited_once()
    worker.generate_endpoint.register_endpoint_instance.assert_awaited_once()
@pytest.mark.asyncio
async def test_resume_with_no_sleeping_state_is_noop():
    """Resume on a freshly built handler reports success without doing any work."""
    worker = _make_handler()
    manager = worker.engine.tokenizer_manager
    endpoint = worker.generate_endpoint

    outcome = await worker.resume_memory_occupation({})

    assert outcome["status"] == "ok"
    assert outcome["message"] == "Memory already resumed"
    manager.resume_memory_occupation.assert_not_awaited()
    manager.continue_generation.assert_not_awaited()
    endpoint.register_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_release_returns_error_when_worker_has_no_tokenizer_manager():
    """Release on a worker without an engine returns an error and never unregisters."""
    expected = {
        "status": "error",
        "message": "memory control not supported on this worker",
    }
    worker = _make_handler()
    worker.engine = None  # no engine means no tokenizer_manager to drive

    outcome = await worker.release_memory_occupation({})

    assert outcome == expected
    worker.generate_endpoint.unregister_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_resume_returns_error_when_worker_has_no_tokenizer_manager():
    """Resume on a worker without an engine returns an error and never re-registers."""
    expected = {
        "status": "error",
        "message": "memory control not supported on this worker",
    }
    worker = _make_handler()
    worker.engine = None  # no engine means no tokenizer_manager to drive

    outcome = await worker.resume_memory_occupation({})

    assert outcome == expected
    worker.generate_endpoint.register_endpoint_instance.assert_not_awaited()
components/src/dynamo/vllm/handlers.py
View file @
c09a9aad
...
@@ -330,6 +330,8 @@ class BaseWorkerHandler(ABC):
...
@@ -330,6 +330,8 @@ class BaseWorkerHandler(ABC):
self
.
use_vllm_tokenizer
=
use_vllm_tokenizer
self
.
use_vllm_tokenizer
=
use_vllm_tokenizer
self
.
dp_range
=
get_dp_range_for_worker
(
self
.
engine_client
.
vllm_config
)
self
.
dp_range
=
get_dp_range_for_worker
(
self
.
engine_client
.
vllm_config
)
self
.
_sleep_wake_lock
=
asyncio
.
Lock
()
self
.
_engine_is_sleeping
=
False
# Initialize InputParamManager for text-in-text-out mode
# Initialize InputParamManager for text-in-text-out mode
tokenizer
=
None
tokenizer
=
None
...
@@ -351,64 +353,74 @@ class BaseWorkerHandler(ABC):
...
@@ -351,64 +353,74 @@ class BaseWorkerHandler(ABC):
2. Abort and drain in-flight requests
2. Abort and drain in-flight requests
3. Sleep engine - safe now that GPU is quiesced
3. Sleep engine - safe now that GPU is quiesced
"""
"""
body
=
body
or
{}
level
=
body
.
get
(
"level"
,
1
)
level
=
body
.
get
(
"level"
,
1
)
try
:
async
with
self
.
_sleep_wake_lock
:
# Step 1: Unregister endpoint instance FIRST to stop new requests from arriving
if
self
.
_engine_is_sleeping
:
return
{
"status"
:
"ok"
,
"message"
:
"Engine already sleeping"
,
}
try
:
try
:
await
self
.
generate_endpoint
.
unregister_endpoint_instance
()
# Step 1: Unregister endpoint instance before memory transitions.
logger
.
info
(
if
self
.
generate_endpoint
is
not
None
:
"[Sleep] Unregistered endpoint from discovery - worker removed from routing pool"
await
self
.
generate_endpoint
.
unregister_endpoint_instance
()
)
logger
.
info
(
except
Exception
as
unreg_err
:
"[Sleep] Unregistered endpoint from discovery - worker removed from routing pool"
logger
.
warning
(
)
f
"[Sleep] Failed to unregister endpoint from discovery:
{
unreg_err
}
"
)
# Step 2: Abort in-flight requests and wait for them to drain so the
# Step 2: Abort in-flight requests and wait for them to drain so the
# GPU is fully quiesced before unmapping memory.
# GPU is fully quiesced before unmapping memory.
await
self
.
engine_client
.
pause_generation
()
await
self
.
engine_client
.
pause_generation
()
# Step 3: Now safe to sleep - no in-flight GPU work
# Step 3: Now safe to sleep - no in-flight GPU work
await
self
.
engine_client
.
sleep
(
level
)
await
self
.
engine_client
.
sleep
(
level
)
self
.
_engine_is_sleeping
=
True
return
{
"status"
:
"ok"
,
"message"
:
f
"Engine slept (level=
{
level
}
)"
}
return
{
except
Exception
as
e
:
"status"
:
"ok"
,
logger
.
error
(
f
"Failed to sleep engine:
{
e
}
"
)
"message"
:
f
"Engine slept (level=
{
level
}
)"
,
return
{
"status"
:
"error"
,
"message"
:
str
(
e
)}
}
except
Exception
as
e
:
logger
.
error
(
f
"Failed to sleep engine:
{
e
}
"
)
return
{
"status"
:
"error"
,
"message"
:
str
(
e
)}
async
def
wake_up
(
self
,
body
:
dict
)
->
dict
:
async
def
wake_up
(
self
,
body
:
dict
)
->
dict
:
"""Wake the engine to restore GPU memory and re-register to discovery.
"""Wake the engine to restore GPU memory and re-register to discovery.
Args:
Args:
body:
Dict with optional 'tags' key (e.g., ["weights", "kv_cache"]). None wakes all
.
body:
Unused. Wake always restores all sleep-managed memory
.
Order of operations:
Order of operations:
1. Wake engine - restore GPU memory
1. Wake engine - restore GPU memory
2. Re-register endpoint instance - allow frontend to route requests here again
2. Re-register endpoint instance - allow frontend to route requests here again
"""
"""
tags
=
body
.
get
(
"tags"
)
async
with
self
.
_sleep_wake_lock
:
try
:
if
not
self
.
_engine_is_sleeping
:
# Step 1: Wake engine first - must be ready before accepting requests
return
{
"status"
:
"ok"
,
"message"
:
"Engine already awake"
}
await
self
.
engine_client
.
wake_up
(
tags
)
# Step 2: Resume generation so new requests can be processed
await
self
.
engine_client
.
resume_generation
()
# Step 3: Re-register endpoint instance to discovery so frontend can route to us again
try
:
try
:
await
self
.
generate_endpoint
.
register_endpoint_instance
()
# Step 1: Wake engine first - must be ready before accepting requests
logger
.
info
(
await
self
.
engine_client
.
wake_up
()
"[Wake] Re-registered endpoint to discovery - worker added back to routing pool"
)
except
Exception
as
reg_err
:
logger
.
warning
(
f
"[Wake] Failed to re-register endpoint to discovery:
{
reg_err
}
"
)
return
{
"status"
:
"ok"
,
"message"
:
f
"Engine woke (tags=
{
tags
}
)"
}
# Step 2: Resume generation and re-register.
except
Exception
as
e
:
await
self
.
engine_client
.
resume_generation
()
logger
.
error
(
f
"Failed to wake up engine:
{
e
}
"
)
if
self
.
generate_endpoint
is
not
None
:
return
{
"status"
:
"error"
,
"message"
:
str
(
e
)}
await
self
.
generate_endpoint
.
register_endpoint_instance
()
logger
.
info
(
"[Wake] Re-registered endpoint to discovery - worker added back to routing pool"
)
self
.
_engine_is_sleeping
=
False
return
{
"status"
:
"ok"
,
"message"
:
"Engine woke"
,
}
except
Exception
as
e
:
logger
.
error
(
f
"Failed to wake up engine:
{
e
}
"
)
return
{
"status"
:
"error"
,
"message"
:
str
(
e
)}
@
abstractmethod
@
abstractmethod
async
def
generate
(
self
,
request
,
context
)
->
AsyncGenerator
[
dict
,
None
]:
async
def
generate
(
self
,
request
,
context
)
->
AsyncGenerator
[
dict
,
None
]:
...
...
components/src/dynamo/vllm/tests/test_vllm_sleep_wake_handlers.py
0 → 100644
View file @
c09a9aad
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
from
types
import
SimpleNamespace
from
unittest.mock
import
AsyncMock
import
pytest
from
dynamo.vllm.handlers
import
BaseWorkerHandler
# Module-level pytest markers: pytest applies every marker in this list to all
# tests collected from this file.
pytestmark = [
    pytest.mark.unit,
    pytest.mark.vllm,
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
]
# Minimal concrete subclass of BaseWorkerHandler used by the tests below; it
# supplies the ``generate`` method that the base class declares abstract.
class _TestWorkerHandler(BaseWorkerHandler):
    # Trivial async generator: yields a single empty dict and ignores its inputs.
    async def generate(self, request, context):
        yield {}
def _make_handler() -> _TestWorkerHandler:
    """Return a handler whose engine client and endpoint are ``AsyncMock`` stubs.

    The instance is created with ``__new__`` so ``__init__`` never runs; only
    the attributes the sleep/wake methods read are populated here.
    """
    fake_engine_client = SimpleNamespace(
        pause_generation=AsyncMock(),
        sleep=AsyncMock(),
        wake_up=AsyncMock(),
        resume_generation=AsyncMock(),
    )
    fake_endpoint = SimpleNamespace(
        unregister_endpoint_instance=AsyncMock(),
        register_endpoint_instance=AsyncMock(),
    )

    stub = _TestWorkerHandler.__new__(_TestWorkerHandler)
    stub.engine_client = fake_engine_client
    stub.generate_endpoint = fake_endpoint
    # State normally set up by the real constructor.
    stub._sleep_wake_lock = asyncio.Lock()
    stub._engine_is_sleeping = False
    return stub
@pytest.mark.asyncio
async def test_wake_up_before_sleep_is_noop():
    """wake_up on a handler that never slept succeeds without touching the engine."""
    handler = _make_handler()

    result = await handler.wake_up({})

    assert result["status"] == "ok"
    # Pin the guarded early-return path: a bare "ok" status alone would also be
    # produced by a real wake, so check the no-op message and the state flag.
    assert result["message"] == "Engine already awake"
    assert handler._engine_is_sleeping is False
    handler.engine_client.wake_up.assert_not_awaited()
    handler.engine_client.resume_generation.assert_not_awaited()
    handler.generate_endpoint.register_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_sleep_and_wake_are_idempotent():
    """Repeating sleep or wake_up succeeds but reaches the engine only once."""
    worker = _make_handler()
    client = worker.engine_client
    endpoint = worker.generate_endpoint

    # Call order matters: sleep twice, then wake twice.
    sleep_first = await worker.sleep({"level": 2})
    sleep_repeat = await worker.sleep({"level": 2})
    wake_first = await worker.wake_up({})
    wake_repeat = await worker.wake_up({})

    for outcome in (sleep_first, sleep_repeat, wake_first, wake_repeat):
        assert outcome["status"] == "ok"

    # Sleep side: one pass through unregister -> pause -> sleep(level).
    client.pause_generation.assert_awaited_once()
    client.sleep.assert_awaited_once_with(2)
    endpoint.unregister_endpoint_instance.assert_awaited_once()
    # Wake side: wake_up is called with no arguments, then resume and register.
    client.wake_up.assert_awaited_once_with()
    client.resume_generation.assert_awaited_once()
    endpoint.register_endpoint_instance.assert_awaited_once()
@pytest.mark.asyncio
async def test_sleep_returns_error_for_unregister_failure():
    """A failing discovery unregister aborts sleep before any engine call."""
    handler = _make_handler()
    handler.generate_endpoint.unregister_endpoint_instance = AsyncMock(
        side_effect=RuntimeError("discovery backend down")
    )

    result = await handler.sleep({"level": 1})

    assert result["status"] == "error"
    # The handler reports str(exception) as the message.
    assert result["message"] == "discovery backend down"
    handler.generate_endpoint.unregister_endpoint_instance.assert_awaited_once()
    handler.engine_client.pause_generation.assert_not_awaited()
    handler.engine_client.sleep.assert_not_awaited()
    # The engine never slept, so the state flag must still say "awake";
    # otherwise a later sleep() would be short-circuited as already sleeping.
    assert handler._engine_is_sleeping is False
@pytest.mark.asyncio
async def test_wake_up_returns_error_for_register_failure():
    """A failing discovery re-register surfaces as an error after the engine woke."""
    handler = _make_handler()
    handler._engine_is_sleeping = True
    handler.generate_endpoint.register_endpoint_instance = AsyncMock(
        side_effect=RuntimeError("discovery write timeout")
    )

    result = await handler.wake_up({})

    assert result["status"] == "error"
    # The handler reports str(exception) as the message.
    assert result["message"] == "discovery write timeout"
    handler.engine_client.wake_up.assert_awaited_once_with()
    handler.engine_client.resume_generation.assert_awaited_once()
    handler.generate_endpoint.register_endpoint_instance.assert_awaited_once()
    # The flag is cleared only after a successful registration, so a failed
    # wake leaves the handler marked as sleeping and a retry is not treated
    # as a no-op.
    assert handler._engine_is_sleeping is True
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment