Unverified Commit 03dd785c authored by Steven Shimizu's avatar Steven Shimizu Committed by GitHub
Browse files

Added async_encode method to Engine (#4701)

parent 66fc63d6
...@@ -285,6 +285,21 @@ class Engine(EngineBase): ...@@ -285,6 +285,21 @@ class Engine(EngineBase):
ret = loop.run_until_complete(generator.__anext__()) ret = loop.run_until_complete(generator.__anext__())
return ret return ret
async def async_encode(
    self,
    prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    image_data: Optional[Union[List[str], str]] = None,
) -> Dict:
    """
    Asynchronous version of the `encode` method.

    The arguments of this function are the same as
    `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
    Please refer to `EmbeddingReqInput` for the documentation.
    """
    request = EmbeddingReqInput(text=prompt, image_data=image_data)
    # generate_request returns an async generator; an embedding request
    # produces a single result, so await only its first item.
    result_stream = self.tokenizer_manager.generate_request(request, None)
    return await result_stream.__anext__()
def shutdown(self): def shutdown(self):
"""Shutdown the engine""" """Shutdown the engine"""
kill_process_tree(os.getpid(), include_parent=False) kill_process_tree(os.getpid(), include_parent=False)
......
...@@ -185,6 +185,35 @@ class TestSRTEngine(CustomTestCase): ...@@ -185,6 +185,35 @@ class TestSRTEngine(CustomTestCase):
result = throughput_test(server_args=server_args, bench_args=bench_args) result = throughput_test(server_args=server_args, bench_args=bench_args)
self.assertGreater(result["total_throughput"], 3000) self.assertGreater(result["total_throughput"], 3000)
def test_8_engine_async_encode_consistency(self):
    """Verify async_encode returns the same embedding as the sync encode."""
    prompt = "Today is a sunny day and I like"
    model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
    engine = sgl.Engine(
        model_path=model_path,
        is_embedding=True,
        random_seed=42,
        disable_radix_cache=True,
    )
    try:
        # Get sync and async embeddings.
        out1 = torch.tensor(engine.encode(prompt)["embedding"])
        # asyncio.run() creates and tears down a fresh event loop;
        # asyncio.get_event_loop() is deprecated outside a running loop
        # (Python 3.10+) and would emit a DeprecationWarning here.
        out2 = torch.tensor(
            asyncio.run(engine.async_encode(prompt))["embedding"]
        )
    finally:
        # Always release the engine, even if encoding raises — otherwise
        # the server process leaks across the remaining tests.
        engine.shutdown()
    print("\n==== Shapes ====")
    print(f"sync shape: {out1.shape}")
    print(f"async shape: {out2.shape}")
    self.assertTrue(
        torch.allclose(out1, out2, atol=1e-5, rtol=1e-3),
        "Sync and async embeddings are not equal within tolerance",
    )
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment