"vscode:/vscode.git/clone" did not exist on "cb080f32e38e87beda897d0602bf6a0d0c79d00f"
Commit d8297312 authored by zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm' into 'v0.9.2-dev'

[fix]v0 SamplerOutput在非tree decoding时不传入logits

See merge request dcutoolkit/deeplearing/vllm!185
parents 9a4a94ee 01692be6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""
import os
import itertools
from collections.abc import Iterator
from dataclasses import dataclass
......@@ -210,6 +211,7 @@ class Sampler(nn.Module):
# speculative decoding and when prompt embeddings are specified.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def _init_sampling_tensors(
self,
......@@ -347,7 +349,7 @@ class Sampler(nn.Module):
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
-logits=logits)
+logits=logits if self.tree_decoding else None)
@property
def _should_modify_greedy_probs_inplace(self) -> bool:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment