Commit d8297312 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm' into 'v0.9.2-dev'

[fix]v0 SamplerOutput在非tree decoding时不传入logits

See merge request dcutoolkit/deeplearing/vllm!185
parents 9a4a94ee 01692be6
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
import os
import itertools import itertools
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
...@@ -210,6 +211,7 @@ class Sampler(nn.Module): ...@@ -210,6 +211,7 @@ class Sampler(nn.Module):
# speculative decoding and when prompt embeddings are specified. # speculative decoding and when prompt embeddings are specified.
self.include_gpu_probs_tensor = False self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False self.should_modify_greedy_probs_inplace = False
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def _init_sampling_tensors( def _init_sampling_tensors(
self, self,
...@@ -347,7 +349,7 @@ class Sampler(nn.Module): ...@@ -347,7 +349,7 @@ class Sampler(nn.Module):
sample_logprobs, sample_logprobs,
on_device_tensors=on_device_tensors, on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output, skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
logits=logits) logits=logits if self.tree_decoding else None)
@property @property
def _should_modify_greedy_probs_inplace(self) -> bool: def _should_modify_greedy_probs_inplace(self) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment