Commit 1c7f32cd authored by zhuwenwen
Browse files

[fix]v0 SamplerOutput在非tree decoding时不传入logits

parent ad038b4e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""
import os
import itertools
from collections.abc import Iterator
from dataclasses import dataclass
......@@ -204,6 +205,7 @@ class Sampler(nn.Module):
# speculative decoding and when prompt embeddings are specified.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def _init_sampling_tensors(
self,
......@@ -341,7 +343,7 @@ class Sampler(nn.Module):
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
logits=logits)
logits=logits if self.tree_decoding else None)
@property
def _should_modify_greedy_probs_inplace(self) -> bool:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment