import sys

from vllm.inputs.data import TokensPrompt as _OriginalTokensPrompt
from vllm.model_executor.layers.rotary_embedding import (
    MRotaryEmbedding as _OriginalMRotaryEmbedding,
)
from vllm.v1.engine import EngineCoreOutput as _OriginalEngineCoreOutput
from vllm.v1.engine import EngineCoreOutputs as _OriginalEngineCoreOutputs
from vllm.v1.engine import EngineCoreRequest as _OriginalEngineCoreRequest
from vllm.v1.request import Request as _OriginalRequest

import vllm_omni.logger  # noqa: F401
from vllm_omni.engine import OmniEngineCoreOutput, OmniEngineCoreOutputs, OmniEngineCoreRequest
from vllm_omni.inputs.data import OmniTokensPrompt
from vllm_omni.model_executor.layers.rotary_embedding import OmniMRotaryEmbedding
from vllm_omni.request import OmniRequest

# Import-time monkey patch: every already-loaded module whose name contains
# "vllm" and that still exposes one of the original vLLM objects below gets
# that attribute rebound to the corresponding Omni replacement.

# Attribute name -> (original vLLM object, Omni replacement).
# Insertion order is the order in which the attributes are patched.
_PATCHES = {
    "EngineCoreOutput": (_OriginalEngineCoreOutput, OmniEngineCoreOutput),
    "EngineCoreOutputs": (_OriginalEngineCoreOutputs, OmniEngineCoreOutputs),
    "TokensPrompt": (_OriginalTokensPrompt, OmniTokensPrompt),
    "MRotaryEmbedding": (_OriginalMRotaryEmbedding, OmniMRotaryEmbedding),
    "Request": (_OriginalRequest, OmniRequest),
    "EngineCoreRequest": (_OriginalEngineCoreRequest, OmniEngineCoreRequest),
}

# Iterate over a snapshot: attribute lookups on a module may trigger further
# imports, and mutating ``sys.modules`` while iterating it directly raises
# ``RuntimeError: dictionary changed size during iteration``.
for module_name, module in list(sys.modules.items()):
    # Only patch modules whose name contains "vllm"; skip everything else.
    if "vllm" not in module_name:
        continue
    for attr_name, (original, replacement) in _PATCHES.items():
        # Identity check: rebind only names that still point at the exact
        # original object, without invoking any custom ``__eq__``.
        if getattr(module, attr_name, None) is original:
            setattr(module, attr_name, replacement)