Unverified Commit 736ed388 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[CI/Build] Fix Args for `_get_logits_warper` in Sampler Test (#5922)

parent 365791ff
......@@ -587,7 +587,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
generation_config = GenerationConfig(top_k=top_k,
top_p=top_p,
do_sample=True)
warpers = generation_model._get_logits_warper(generation_config)
warpers = generation_model._get_logits_warper(generation_config, device)
assert len(warpers) == 2 # top_p and top_k
seq_group_metadata_list: List[SequenceGroupMetadata] = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment