from contextlib import contextmanager
from types import SimpleNamespace

import torch

from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner


class DummyBuffer:
    """A minimal buffer wrapper that exposes the `.gpu` attribute."""

    def __init__(self, t: torch.Tensor):
        self.gpu = t


class DummyInputBatch:
    """A minimal input batch that only provides `req_ids`."""

    def __init__(self, req_ids):
        self.req_ids = req_ids


class DummyReqState:
    """A minimal request state container."""


class DummyTalkerMTP(torch.nn.Module):
    """A fake talker_mtp module for deterministic CPU testing."""

    def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step):
        # Deterministic behavior:
        # - output embeds = input embeds + 1
        # - output codes = [[0], [1], ...]
        bsz = req_embeds.shape[0]
        new_embeds = req_embeds + 1.0
        codes = torch.arange(bsz, dtype=torch.int64).view(bsz, 1)
        return new_embeds, codes


@contextmanager
def _noop_forward_context(*args, **kwargs):
    """A no-op context manager that replaces the vLLM forward context in CPU tests."""
    yield


def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
    # Create an instance without calling OmniGPUModelRunner.__init__.
    runner = object.__new__(OmniGPUModelRunner)

    # Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward.
    runner.input_batch = DummyInputBatch(list(req_ids))
    runner.requests = {rid: DummyReqState() for rid in req_ids}

    # query_start_loc.cpu[req_index] is used to locate the token position
    # in the flattened `inputs_embeds`.
    runner.query_start_loc = type("QSL", (), {})()
    # Map: r1 -> offset 0, r2 -> offset 3.
    runner.query_start_loc.cpu = torch.tensor([0, 3], dtype=torch.int32)

    bsz = len(req_ids)
    runner.talker_mtp_input_ids = DummyBuffer(torch.zeros((bsz,), dtype=torch.int64))
    runner.talker_mtp_inputs_embeds = DummyBuffer(
        torch.zeros((bsz, hidden_size), dtype=torch.float32)
    )
    runner.last_talker_hidden = DummyBuffer(
        torch.zeros((bsz, hidden_size), dtype=torch.float32)
    )
    runner.text_step = DummyBuffer(
        torch.zeros((bsz, hidden_size), dtype=torch.float32)
    )

    runner.talker_mtp = DummyTalkerMTP()
    runner.vllm_config = object()

    # Provide a minimal implementation that returns the expected 5-tuple.
    def _determine_batch_execution_and_padding(**kwargs):
        return None, object(), None, None, None

    runner._determine_batch_execution_and_padding = _determine_batch_execution_and_padding

    # The tests below call the real OmniGPUModelRunner._talker_mtp_forward
    # against this stubbed runner.
    return runner


def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch):
    # Patch the module-level `set_forward_context` symbol used inside
    # OmniGPUModelRunner._talker_mtp_forward.
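    # Patching the attribute on the defining module is sufficient because
    # Python resolves the name from that module's globals at call time, so
    # no GPU or real vLLM forward context is needed for this test.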
    # Must be the same module that defines OmniGPUModelRunner.
    import vllm_omni.worker.gpu_model_runner as mod

    monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)

    runner = _make_runner(req_ids=("r1", "r2"), hidden_size=4)

    def fake_determine(
        self,
        num_tokens,
        num_reqs,
        num_scheduled_tokens_np,
        max_num_scheduled_tokens,
        use_cascade_attn,
    ):
        batch_desc = SimpleNamespace(num_tokens=int(num_tokens))
        return (False, batch_desc, None, None, None)

    monkeypatch.setattr(
        runner,
        "_determine_batch_execution_and_padding",
        fake_determine.__get__(runner, type(runner)),
    )

    # Initialize per-request embeds (batch-major inside talker_mtp_inputs_embeds).
    runner.talker_mtp_inputs_embeds.gpu[0] = torch.tensor([1.0, 2.0, 3.0, 4.0])
    runner.talker_mtp_inputs_embeds.gpu[1] = torch.tensor([10.0, 20.0, 30.0, 40.0])

    # Flattened `inputs_embeds`: offsets 0 and 3 will be overwritten.
    inputs_embeds = torch.zeros((6, 4), dtype=torch.float32)

    # Call the original implementation from OmniGPUModelRunner (no re-implementation).
    OmniGPUModelRunner._talker_mtp_forward(runner, ["r1", "r2"], inputs_embeds)

    # Validate that the embeds were written back (+1).
    assert torch.allclose(inputs_embeds[0], torch.tensor([2.0, 3.0, 4.0, 5.0]))
    assert torch.allclose(inputs_embeds[3], torch.tensor([11.0, 21.0, 31.0, 41.0]))

    # Validate that per-request additional_information_cpu was updated.
    info_r1 = runner.requests["r1"].additional_information_cpu
    info_r2 = runner.requests["r2"].additional_information_cpu
    assert int(info_r1["code_predictor_codes"][0, 0]) == 0
    assert int(info_r2["code_predictor_codes"][0, 0]) == 1


def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):
    import vllm_omni.worker.gpu_model_runner as mod

    monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)

    runner = _make_runner(req_ids=("r1",), hidden_size=4)

    inputs_embeds = torch.randn((2, 4))
    before = inputs_embeds.clone()

    OmniGPUModelRunner._talker_mtp_forward(runner, [], inputs_embeds)

    # Ensure no changes were made.
    assert torch.allclose(inputs_embeds, before)
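

# A small extra sanity check, independent of OmniGPUModelRunner: it pins down
# the deterministic contract of the DummyTalkerMTP stub (+1 on the embeds,
# arange codes) that the assertions above rely on. A minimal sketch assuming
# the stub is invoked like any torch.nn.Module.
def test_dummy_talker_mtp_is_deterministic():
    mtp = DummyTalkerMTP()
    embeds = torch.zeros((3, 4), dtype=torch.float32)
    new_embeds, codes = mtp(None, embeds, None, None)
    # Embeds are shifted by exactly +1 and codes enumerate the batch rows.
    assert torch.allclose(new_embeds, torch.ones((3, 4)))
    assert codes.tolist() == [[0], [1], [2]]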