Unverified Commit ce11dd82 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

[CI] Try fix broken event loop init (#11746)

parent 9e87b60f
...@@ -30,8 +30,6 @@ import time ...@@ -30,8 +30,6 @@ import time
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import zmq import zmq
import zmq.asyncio
from PIL.Image import Image
from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
...@@ -147,6 +145,12 @@ class Engine(EngineBase): ...@@ -147,6 +145,12 @@ class Engine(EngineBase):
thread_label = "Tokenizer" thread_label = "Tokenizer"
trace_set_thread_info(thread_label) trace_set_thread_info(thread_label)
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def generate( def generate(
self, self,
# The input prompt. It can be a single prompt or a batch of prompts. # The input prompt. It can be a single prompt or a batch of prompts.
...@@ -210,7 +214,6 @@ class Engine(EngineBase): ...@@ -210,7 +214,6 @@ class Engine(EngineBase):
bootstrap_room=bootstrap_room, bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
) )
loop = asyncio.get_event_loop()
generator = self.tokenizer_manager.generate_request(obj, None) generator = self.tokenizer_manager.generate_request(obj, None)
if stream: if stream:
...@@ -218,14 +221,14 @@ class Engine(EngineBase): ...@@ -218,14 +221,14 @@ class Engine(EngineBase):
def generator_wrapper(): def generator_wrapper():
while True: while True:
try: try:
chunk = loop.run_until_complete(generator.__anext__()) chunk = self.loop.run_until_complete(generator.__anext__())
yield chunk yield chunk
except StopAsyncIteration: except StopAsyncIteration:
break break
return generator_wrapper() return generator_wrapper()
else: else:
ret = loop.run_until_complete(generator.__anext__()) ret = self.loop.run_until_complete(generator.__anext__())
return ret return ret
async def async_generate( async def async_generate(
...@@ -317,9 +320,8 @@ class Engine(EngineBase): ...@@ -317,9 +320,8 @@ class Engine(EngineBase):
audio_data=audio_data, audio_data=audio_data,
video_data=video_data, video_data=video_data,
) )
loop = asyncio.get_event_loop()
generator = self.tokenizer_manager.generate_request(obj, None) generator = self.tokenizer_manager.generate_request(obj, None)
ret = loop.run_until_complete(generator.__anext__()) ret = self.loop.run_until_complete(generator.__anext__())
return ret return ret
async def async_encode( async def async_encode(
...@@ -353,9 +355,8 @@ class Engine(EngineBase): ...@@ -353,9 +355,8 @@ class Engine(EngineBase):
Please refer to `EmbeddingReqInput` for the documentation. Please refer to `EmbeddingReqInput` for the documentation.
""" """
obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True) obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
loop = asyncio.get_event_loop()
generator = self.tokenizer_manager.generate_request(obj, None) generator = self.tokenizer_manager.generate_request(obj, None)
ret = loop.run_until_complete(generator.__anext__()) ret = self.loop.run_until_complete(generator.__anext__())
return ret return ret
def shutdown(self): def shutdown(self):
...@@ -370,38 +371,31 @@ class Engine(EngineBase): ...@@ -370,38 +371,31 @@ class Engine(EngineBase):
return False return False
def flush_cache(self): def flush_cache(self):
loop = asyncio.get_event_loop() return self.loop.run_until_complete(self.tokenizer_manager.flush_cache())
return loop.run_until_complete(self.tokenizer_manager.flush_cache())
def start_profile(self, **kwargs): def start_profile(self, **kwargs):
loop = asyncio.get_event_loop() self.loop.run_until_complete(self.tokenizer_manager.start_profile(**kwargs))
loop.run_until_complete(self.tokenizer_manager.start_profile(**kwargs))
def stop_profile(self): def stop_profile(self):
loop = asyncio.get_event_loop() self.loop.run_until_complete(self.tokenizer_manager.stop_profile())
loop.run_until_complete(self.tokenizer_manager.stop_profile())
def start_expert_distribution_record(self): def start_expert_distribution_record(self):
loop = asyncio.get_event_loop() self.loop.run_until_complete(
loop.run_until_complete(
self.tokenizer_manager.start_expert_distribution_record() self.tokenizer_manager.start_expert_distribution_record()
) )
def stop_expert_distribution_record(self): def stop_expert_distribution_record(self):
loop = asyncio.get_event_loop() self.loop.run_until_complete(
loop.run_until_complete(
self.tokenizer_manager.stop_expert_distribution_record() self.tokenizer_manager.stop_expert_distribution_record()
) )
def dump_expert_distribution_record(self): def dump_expert_distribution_record(self):
loop = asyncio.get_event_loop() self.loop.run_until_complete(
loop.run_until_complete(
self.tokenizer_manager.dump_expert_distribution_record() self.tokenizer_manager.dump_expert_distribution_record()
) )
def get_server_info(self): def get_server_info(self):
loop = asyncio.get_event_loop() internal_states = self.loop.run_until_complete(
internal_states = loop.run_until_complete(
self.tokenizer_manager.get_internal_state() self.tokenizer_manager.get_internal_state()
) )
return { return {
...@@ -429,8 +423,7 @@ class Engine(EngineBase): ...@@ -429,8 +423,7 @@ class Engine(EngineBase):
group_name=group_name, group_name=group_name,
backend=backend, backend=backend,
) )
loop = asyncio.get_event_loop() return self.loop.run_until_complete(
return loop.run_until_complete(
self.tokenizer_manager.init_weights_update_group(obj, None) self.tokenizer_manager.init_weights_update_group(obj, None)
) )
...@@ -442,8 +435,7 @@ class Engine(EngineBase): ...@@ -442,8 +435,7 @@ class Engine(EngineBase):
obj = DestroyWeightsUpdateGroupReqInput( obj = DestroyWeightsUpdateGroupReqInput(
group_name=group_name, group_name=group_name,
) )
loop = asyncio.get_event_loop() return self.loop.run_until_complete(
return loop.run_until_complete(
self.tokenizer_manager.destroy_weights_update_group(obj, None) self.tokenizer_manager.destroy_weights_update_group(obj, None)
) )
...@@ -463,8 +455,7 @@ class Engine(EngineBase): ...@@ -463,8 +455,7 @@ class Engine(EngineBase):
group_name=group_name, group_name=group_name,
flush_cache=flush_cache, flush_cache=flush_cache,
) )
loop = asyncio.get_event_loop() return self.loop.run_until_complete(
return loop.run_until_complete(
self.tokenizer_manager.update_weights_from_distributed(obj, None) self.tokenizer_manager.update_weights_from_distributed(obj, None)
) )
...@@ -488,9 +479,7 @@ class Engine(EngineBase): ...@@ -488,9 +479,7 @@ class Engine(EngineBase):
load_format=load_format, load_format=load_format,
flush_cache=flush_cache, flush_cache=flush_cache,
) )
loop = asyncio.get_event_loop() return self.loop.run_until_complete(
return loop.run_until_complete(
self.tokenizer_manager.update_weights_from_tensor(obj, None) self.tokenizer_manager.update_weights_from_tensor(obj, None)
) )
...@@ -510,16 +499,14 @@ class Engine(EngineBase): ...@@ -510,16 +499,14 @@ class Engine(EngineBase):
load_format=load_format, load_format=load_format,
) )
loop = asyncio.get_event_loop() return self.loop.run_until_complete(
return loop.run_until_complete(
self.tokenizer_manager.update_weights_from_disk(obj, None) self.tokenizer_manager.update_weights_from_disk(obj, None)
) )
def get_weights_by_name(self, name: str, truncate_size: int = 100): def get_weights_by_name(self, name: str, truncate_size: int = 100):
"""Get weights by parameter name.""" """Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
loop = asyncio.get_event_loop() return self.loop.run_until_complete(
return loop.run_until_complete(
self.tokenizer_manager.get_weights_by_name(obj, None) self.tokenizer_manager.get_weights_by_name(obj, None)
) )
...@@ -532,8 +519,7 @@ class Engine(EngineBase): ...@@ -532,8 +519,7 @@ class Engine(EngineBase):
pinned=pinned, pinned=pinned,
) )
loop = asyncio.get_event_loop() return self.loop.run_until_complete(
return loop.run_until_complete(
self.tokenizer_manager.load_lora_adapter(obj, None) self.tokenizer_manager.load_lora_adapter(obj, None)
) )
...@@ -542,22 +528,19 @@ class Engine(EngineBase): ...@@ -542,22 +528,19 @@ class Engine(EngineBase):
obj = UnloadLoRAAdapterReqInput(lora_name=lora_name) obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)
loop = asyncio.get_event_loop() return self.loop.run_until_complete(
return loop.run_until_complete(
self.tokenizer_manager.unload_lora_adapter(obj, None) self.tokenizer_manager.unload_lora_adapter(obj, None)
) )
def release_memory_occupation(self, tags: Optional[List[str]] = None): def release_memory_occupation(self, tags: Optional[List[str]] = None):
obj = ReleaseMemoryOccupationReqInput(tags=tags) obj = ReleaseMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop() return self.loop.run_until_complete(
return loop.run_until_complete(
self.tokenizer_manager.release_memory_occupation(obj, None) self.tokenizer_manager.release_memory_occupation(obj, None)
) )
def resume_memory_occupation(self, tags: Optional[List[str]] = None): def resume_memory_occupation(self, tags: Optional[List[str]] = None):
obj = ResumeMemoryOccupationReqInput(tags=tags) obj = ResumeMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop() return self.loop.run_until_complete(
return loop.run_until_complete(
self.tokenizer_manager.resume_memory_occupation(obj, None) self.tokenizer_manager.resume_memory_occupation(obj, None)
) )
...@@ -574,8 +557,7 @@ class Engine(EngineBase): ...@@ -574,8 +557,7 @@ class Engine(EngineBase):
collection. collection.
""" """
loop = asyncio.get_event_loop() self.loop.run_until_complete(self.tokenizer_manager.freeze_gc())
loop.run_until_complete(self.tokenizer_manager.freeze_gc())
""" """
Execute an RPC call on all scheduler processes. Execute an RPC call on all scheduler processes.
...@@ -633,8 +615,7 @@ class Engine(EngineBase): ...@@ -633,8 +615,7 @@ class Engine(EngineBase):
ValueError: If query is not provided, or if items is not provided, ValueError: If query is not provided, or if items is not provided,
or if token IDs are out of vocabulary, or if logprobs are not available for the specified tokens. or if token IDs are out of vocabulary, or if logprobs are not available for the specified tokens.
""" """
loop = asyncio.get_event_loop() return self.loop.run_until_complete(
return loop.run_until_complete(
self.tokenizer_manager.score_request( self.tokenizer_manager.score_request(
query=query, query=query,
items=items, items=items,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment