"torchvision/vscode:/vscode.git/clone" did not exist on "cac4e228c9ca9e7564cb34406e7ebccfdd736976"
Unverified Commit d3a5bd9f authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Fix sampler test (#1379)

parent e8ef4c08
# pylint: disable=protected-access
import pytest
import random
from typing import Tuple
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.layers.sampler import Sampler
......@@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
input_metadata=input_metadata)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
for nth_output in sequence_output.samples:
assert nth_output.output_token == expected[i].item()
......@@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
for nth_output in sequence_output.samples:
assert nth_output.output_token == i
......@@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
for i, sequence_output in enumerate(sampler_output):
if seq_group_metadata_list[i].sampling_params.use_beam_search:
continue
for nth_output in sequence_output:
for nth_output in sequence_output.samples:
assert nth_output.output_token in expected_tokens
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment