Unverified Commit ce11dd82 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

[CI] Try fix broken event loop init (#11746)

parent 9e87b60f
......@@ -30,8 +30,6 @@ import time
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import zmq
import zmq.asyncio
from PIL.Image import Image
from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
......@@ -147,6 +145,12 @@ class Engine(EngineBase):
thread_label = "Tokenizer"
trace_set_thread_info(thread_label)
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
......@@ -210,7 +214,6 @@ class Engine(EngineBase):
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
)
loop = asyncio.get_event_loop()
generator = self.tokenizer_manager.generate_request(obj, None)
if stream:
......@@ -218,14 +221,14 @@ class Engine(EngineBase):
def generator_wrapper():
while True:
try:
chunk = loop.run_until_complete(generator.__anext__())
chunk = self.loop.run_until_complete(generator.__anext__())
yield chunk
except StopAsyncIteration:
break
return generator_wrapper()
else:
ret = loop.run_until_complete(generator.__anext__())
ret = self.loop.run_until_complete(generator.__anext__())
return ret
async def async_generate(
......@@ -317,9 +320,8 @@ class Engine(EngineBase):
audio_data=audio_data,
video_data=video_data,
)
loop = asyncio.get_event_loop()
generator = self.tokenizer_manager.generate_request(obj, None)
ret = loop.run_until_complete(generator.__anext__())
ret = self.loop.run_until_complete(generator.__anext__())
return ret
async def async_encode(
......@@ -353,9 +355,8 @@ class Engine(EngineBase):
Please refer to `EmbeddingReqInput` for the documentation.
"""
obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
loop = asyncio.get_event_loop()
generator = self.tokenizer_manager.generate_request(obj, None)
ret = loop.run_until_complete(generator.__anext__())
ret = self.loop.run_until_complete(generator.__anext__())
return ret
def shutdown(self):
......@@ -370,38 +371,31 @@ class Engine(EngineBase):
return False
def flush_cache(self):
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.tokenizer_manager.flush_cache())
return self.loop.run_until_complete(self.tokenizer_manager.flush_cache())
def start_profile(self, **kwargs):
loop = asyncio.get_event_loop()
loop.run_until_complete(self.tokenizer_manager.start_profile(**kwargs))
self.loop.run_until_complete(self.tokenizer_manager.start_profile(**kwargs))
def stop_profile(self):
loop = asyncio.get_event_loop()
loop.run_until_complete(self.tokenizer_manager.stop_profile())
self.loop.run_until_complete(self.tokenizer_manager.stop_profile())
def start_expert_distribution_record(self):
loop = asyncio.get_event_loop()
loop.run_until_complete(
self.loop.run_until_complete(
self.tokenizer_manager.start_expert_distribution_record()
)
def stop_expert_distribution_record(self):
loop = asyncio.get_event_loop()
loop.run_until_complete(
self.loop.run_until_complete(
self.tokenizer_manager.stop_expert_distribution_record()
)
def dump_expert_distribution_record(self):
loop = asyncio.get_event_loop()
loop.run_until_complete(
self.loop.run_until_complete(
self.tokenizer_manager.dump_expert_distribution_record()
)
def get_server_info(self):
loop = asyncio.get_event_loop()
internal_states = loop.run_until_complete(
internal_states = self.loop.run_until_complete(
self.tokenizer_manager.get_internal_state()
)
return {
......@@ -429,8 +423,7 @@ class Engine(EngineBase):
group_name=group_name,
backend=backend,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
return self.loop.run_until_complete(
self.tokenizer_manager.init_weights_update_group(obj, None)
)
......@@ -442,8 +435,7 @@ class Engine(EngineBase):
obj = DestroyWeightsUpdateGroupReqInput(
group_name=group_name,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
return self.loop.run_until_complete(
self.tokenizer_manager.destroy_weights_update_group(obj, None)
)
......@@ -463,8 +455,7 @@ class Engine(EngineBase):
group_name=group_name,
flush_cache=flush_cache,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
return self.loop.run_until_complete(
self.tokenizer_manager.update_weights_from_distributed(obj, None)
)
......@@ -488,9 +479,7 @@ class Engine(EngineBase):
load_format=load_format,
flush_cache=flush_cache,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
return self.loop.run_until_complete(
self.tokenizer_manager.update_weights_from_tensor(obj, None)
)
......@@ -510,16 +499,14 @@ class Engine(EngineBase):
load_format=load_format,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
return self.loop.run_until_complete(
self.tokenizer_manager.update_weights_from_disk(obj, None)
)
def get_weights_by_name(self, name: str, truncate_size: int = 100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
return self.loop.run_until_complete(
self.tokenizer_manager.get_weights_by_name(obj, None)
)
......@@ -532,8 +519,7 @@ class Engine(EngineBase):
pinned=pinned,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
return self.loop.run_until_complete(
self.tokenizer_manager.load_lora_adapter(obj, None)
)
......@@ -542,22 +528,19 @@ class Engine(EngineBase):
obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
return self.loop.run_until_complete(
self.tokenizer_manager.unload_lora_adapter(obj, None)
)
def release_memory_occupation(self, tags: Optional[List[str]] = None):
obj = ReleaseMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
return self.loop.run_until_complete(
self.tokenizer_manager.release_memory_occupation(obj, None)
)
def resume_memory_occupation(self, tags: Optional[List[str]] = None):
obj = ResumeMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
return self.loop.run_until_complete(
self.tokenizer_manager.resume_memory_occupation(obj, None)
)
......@@ -574,8 +557,7 @@ class Engine(EngineBase):
collection.
"""
loop = asyncio.get_event_loop()
loop.run_until_complete(self.tokenizer_manager.freeze_gc())
self.loop.run_until_complete(self.tokenizer_manager.freeze_gc())
"""
Execute an RPC call on all scheduler processes.
......@@ -633,8 +615,7 @@ class Engine(EngineBase):
ValueError: If query is not provided, or if items is not provided,
or if token IDs are out of vocabulary, or if logprobs are not available for the specified tokens.
"""
loop = asyncio.get_event_loop()
return loop.run_until_complete(
return self.loop.run_until_complete(
self.tokenizer_manager.score_request(
query=query,
items=items,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment