Commit 01692be6 authored by 王敏's avatar 王敏
Browse files

[fix]v0 SamplerOutput在非tree decoding时不传入logits

parent 9a4a94ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""
import os
import itertools
from collections.abc import Iterator
from dataclasses import dataclass
......@@ -210,6 +211,7 @@ class Sampler(nn.Module):
# speculative decoding and when prompt embeddings are specified.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def _init_sampling_tensors(
self,
......@@ -347,7 +349,7 @@ class Sampler(nn.Module):
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
- logits=logits)
+ logits=logits if self.tree_decoding else None)
@property
def _should_modify_greedy_probs_inplace(self) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment