Commit ffa31925 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev' into v0.7.2-fusion

parents 9e813a0e a9267f52
......@@ -38,7 +38,7 @@ from transformers import PreTrainedTokenizerBase
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser
import triton
# import triton
try:
......@@ -46,23 +46,23 @@ try:
except ImportError:
from backend_request_func import get_tokenizer
triton_version = triton.__version__
if triton_version.startswith("2.1"):
from triton.common.backend import compute_core_version_key
elif triton_version.startswith("3.0"):
from triton.compiler.compiler import triton_key
else:
print(f"TRITON version {triton_version} is not specifically handled.")
# triton_version = triton.__version__
# if triton_version.startswith("2.1"):
# from triton.common.backend import compute_core_version_key
# elif triton_version.startswith("3.0"):
# from triton.compiler.compiler import triton_key
# else:
# print(f"TRITON version {triton_version} is not specifically handled.")
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
def test_prefix(llm=None, sampling_params=None, prompts=None):
if triton_version.startswith("2.1"):
version_key = compute_core_version_key()
if triton_version.startswith("3.0"):
version_key = triton_key()
# if triton_version.startswith("2.1"):
# version_key = compute_core_version_key()
# if triton_version.startswith("3.0"):
# version_key = triton_key()
start_time = time.time()
llm.generate(prompts, sampling_params=sampling_params)
......
......@@ -488,11 +488,11 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None:
sha = get_sha(vllm_root)
if (major, minor) == ('2', '4'):
version = 'das.opt2.alpha.' + sha[:7]
version = 'das.opt1.' + sha[:7]
# version = 'das.opt1.' + sha[:7]
else:
if (major, minor) == ('2', '4'):
version = 'das.opt2.alpha'
version = 'das.opt1'
# version = 'das.opt1'
......
......@@ -5,6 +5,8 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch
import triton
from triton.compiler.compiler import triton_key
import vllm.envs as envs
from vllm import _custom_ops as ops
......@@ -30,7 +32,7 @@ _ON_NAVI = "gfx1" in _GPU_ARCH
_ON_MI250_MI300 = any(arch in _GPU_ARCH
for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"])
class ROCmFlashAttentionBackend(AttentionBackend):
@staticmethod
......@@ -778,6 +780,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
# prefix-enabled attention
# not applicable for encoder-only models
version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY:
output[:
num_prefill_tokens] = PagedAttention.forward_prefix(
......
......@@ -43,7 +43,7 @@ from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import SampleRecorder, SamplerOutput, get_last_sampler
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
......@@ -414,7 +414,6 @@ class LLMEngine:
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
if self.zero_overhead:
assert os.environ.get('HIP_ALLOC_INITIALIZE') == '0'
self.async_d2h = None
self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False)
......@@ -1246,12 +1245,13 @@ class LLMEngine:
def _fix_last_step(
self, output: List[SamplerOutput],
last_sampler: SampleRecorder,
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
sample_out_list = self.async_d2h.tolist()
sample_out_ids = output[0].sampler_out_ids.tolist()
sample_out_ids = last_sampler.seq_ids
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
......@@ -1339,12 +1339,9 @@ class LLMEngine:
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
last_outputs_ids = None
last_outputs_tensor = None
if self.last_record is not None:
last_output = self.last_record[0][0]
last_outputs_ids, last_outputs_tensor = last_output.sampler_out_ids, last_output.sampler_out_tenosr
self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True)
last_sampler = self.last_record[1]
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
self.async_event.record()
self.q_recorder.put(self.last_record)
else:
......@@ -1371,9 +1368,7 @@ class LLMEngine:
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids,
last_outputs_ids = last_outputs_ids,
last_outputs_sample = last_outputs_tensor)
last_sampled_token_ids=last_sampled_token_ids)
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
......@@ -1383,7 +1378,8 @@ class LLMEngine:
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] #deep copy
self.last_record = [outputs, seq_group_metadata_list, scheduler_outputs]
last_sampler = get_last_sampler()
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs]
except Exception as e:
print(f"thread_zero_overhead error : {e}")
......@@ -1402,12 +1398,12 @@ class LLMEngine:
virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine]
ctx.request_outputs.clear()
outputs, seq_group_metadata_list, scheduler_outputs = recode_output
outputs, last_sampler, seq_group_metadata_list, scheduler_outputs = recode_output
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
self.async_event.synchronize()
self._fix_last_step(
outputs, seq_group_metadata_list,
outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# is_first_step_output is True only when the num_steps of all
......
......@@ -1412,7 +1412,6 @@ class LLM:
if use_tqdm:
pbar.close()
self.llm_engine.finish_thread()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
......
......@@ -70,15 +70,19 @@ class SampleResultArgsType:
sampling_metadata: SamplingMetadata
greedy_samples: Optional[torch.Tensor]
beam_search_logprobs: Optional[torch.Tensor]
# Implemented by guanyu
@dataclass
class SampleDeviceToDevices:
class SampleRecorder:
def __init__(self):
self.seq_id:torch.Tensor = None
self.seq_ids:torch.Tensor = None
self.sampled_token_ids_tensor:torch.Tensor = None
self.zero_overhead:bool = False
d2d_data = SampleDeviceToDevices()
last_sampler = None
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
def get_last_sampler():
return last_sampler
# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
......@@ -214,8 +218,6 @@ class Sampler(nn.Module):
# speculative decoding.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
d2d_data.zero_overhead = self.zero_overhead
def _init_sampling_tensors(
self,
......@@ -266,6 +268,8 @@ class Sampler(nn.Module):
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
global last_sampler
last_sampler = SampleRecorder()
assert logits is not None
_, vocab_size = logits.shape
......@@ -476,7 +480,7 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
if not d2d_data.zero_overhead:
if not zero_overhead:
samples_lst = samples.tolist()
sample_idx = 0
results: SampleResultType = []
......@@ -490,7 +494,7 @@ def _greedy_sample(
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead:
if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] #place holder token id
else:
......@@ -517,7 +521,7 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
if not d2d_data.zero_overhead:
if not zero_overhead:
random_samples = random_samples.cpu()
sample_idx = 0
results: SampleResultType = []
......@@ -533,7 +537,7 @@ def _random_sample(
if is_prompt:
# Prompt phase.
parent_ids = [0] * sampling_params.n
if d2d_data.zero_overhead:
if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * sampling_params.n #place holder token id
else:
......@@ -542,7 +546,7 @@ def _random_sample(
else:
# Generation phase.
parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead:
if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * num_parent_seqs #place holder token id
else:
......@@ -763,10 +767,10 @@ def _sample_with_torch(
t: []
for t in SamplingType
}
d2d_data.seq_id = torch.zeros(len(sampling_metadata.seq_groups), dtype=torch.int32)
last_sampler.seq_ids = []
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
d2d_data.seq_id[i] = seq_group.seq_ids[0]
last_sampler.seq_ids.append(seq_group.seq_ids[0])
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
......@@ -801,8 +805,7 @@ def _sample_with_torch(
greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)
if d2d_data.zero_overhead:
d2d_data.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
last_sampler.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
......@@ -841,9 +844,8 @@ def _sample_with_torch(
max_n_in_batch,
seq_groups=seq_groups_arg)
if d2d_data.zero_overhead:
d2d_data.sampled_token_ids_tensor = \
multinomial_samples[sampling_type].to(torch.long)
last_sampler.sampled_token_ids_tensor = \
multinomial_samples[sampling_type].to(torch.long)
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
......@@ -1308,9 +1310,7 @@ def _build_sampler_output(
sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor,
deferred_sample_results_args=deferred_sample_results_args,
logits=logits,
sampler_out_tenosr = d2d_data.sampled_token_ids_tensor,
sampler_out_ids = d2d_data.seq_id)
logits=logits)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
......
import torch
import triton
import triton.language as tl
@triton.jit
def _update_input_tokens(
sample_output,
seq_ids,
input_tokens,
input_seq_ids,
BATCH_SIZE1,
BATCH_SIZE2,
):
pid = tl.program_id(0)
if pid >= BATCH_SIZE2:
return
output_token = tl.load(input_tokens + pid)
_input_seq_id = tl.load(input_seq_ids + pid)
for i in range(BATCH_SIZE1):
_seq_ids = tl.load(seq_ids + i)
if _seq_ids == _input_seq_id:
output_token = tl.load(sample_output + i)
tl.store(input_tokens + pid, output_token)
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
grid = [input_seq_ids.shape[0], 1, 1]
_update_input_tokens[grid](last_sample, last_ids, input_tokens, input_seq_ids, last_ids.shape[0], input_seq_ids.shape[0])
\ No newline at end of file
......@@ -1465,12 +1465,6 @@ class ExecuteModelRequest(
# Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
# for zero-overhead scheduler
last_outputs_sample : Optional[torch.Tensor] = None
# for zero-overhead scheduler
last_outputs_ids : Optional[torch.Tensor] = None
@property
def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of
......@@ -1520,9 +1514,7 @@ class ExecuteModelRequest(
async_callback=self.async_callback,
tree_attn_masks=self.tree_attn_masks,
tree_position_ids=self.tree_position_ids,
kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved,
last_outputs_sample = self.last_outputs_sample,
last_outputs_ids = self.last_outputs_ids)
kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved)
@dataclass
......
......@@ -69,11 +69,6 @@ class AsyncMetricsCollector:
self._in_flight_copy: Optional[torch.cuda.Event] = None
pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_draft_tokens = 0
self._rejsample_metrics_collect_interval_s = collect_interval_s
......@@ -88,10 +83,17 @@ class AsyncMetricsCollector:
device_type: Union[torch.device, str] = 'cuda') -> None:
self._rank = rank
if isinstance(device_type, torch.device):
torch.cuda.set_device(device_type)
device_type = device_type.type
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()
pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# currently using cuda.Event, skip for any non_cuda_alike platform
......
......@@ -30,13 +30,11 @@ class TargetModelRunner(ModelRunnerWrapperBase):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
finished_requests_ids: Optional[List[str]] = None
) -> ModelRunnerInputBase:
model_input: ModelRunnerInputBase =\
self.model_runner.prepare_model_input(
seq_group_metadata_list, virtual_engine, finished_requests_ids, last_outputs_ids, last_output_sample)
seq_group_metadata_list, virtual_engine, finished_requests_ids)
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the
......
......@@ -36,7 +36,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_last_sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal
......@@ -61,8 +61,6 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from vllm.model_executor.layers.update_input import UpdateInputTokens
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
......@@ -479,14 +477,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
self.last_sample_tensor = None
self.last_sample_ids = None
self.req_ids = []
def SetLastSamperData(self, last_sample_ids, last_sample_tensor):
self.last_sample_tensor = last_sample_tensor
self.last_sample_ids = last_sample_ids
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.finished_requests_ids = finished_requests_ids
......@@ -915,14 +907,25 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.runner.device,
self.runner.pin_memory)
if self.zero_overhead and self.last_sample_tensor is not None:
input_ids = async_tensor_h2d(self.req_ids, torch.long,
self.runner.device,
self.runner.pin_memory)
last_ids = async_tensor_h2d(self.last_sample_ids.tolist(), torch.long,
self.runner.device,
self.runner.pin_memory)
UpdateInputTokens(input_tokens_tensor, input_ids, self.last_sample_tensor, last_ids)
if self.zero_overhead:
last_sampler = get_last_sampler()
if last_sampler is not None:
update_indices = []
select_indices = []
for i, seq_id in enumerate(self.req_ids):
for j, seq_id_ in enumerate(last_sampler.seq_ids):
if seq_id == seq_id_:
select_indices.append(j)
update_indices.append(i)
break
if len(select_indices) > 0:
select_indices = async_tensor_h2d(select_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
input_tokens_tensor[update_indices] = last_sampler.sampled_token_ids_tensor[select_indices, 0]
token_types_tensor = async_tensor_h2d(token_types, torch.long,
self.runner.device,
......@@ -1225,9 +1228,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
finished_requests_ids: Optional[List[str]] = None
) -> TModelInputForGPU:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
......@@ -1248,7 +1249,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.builder.add_seq_group(seq_group_metadata)
self.builder.reset_cached_inter_data()
self.builder.SetLastSamperData(last_outputs_ids, last_output_sample)
return self.builder.build() # type: ignore
@contextmanager
......@@ -1642,9 +1642,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
......@@ -1660,7 +1658,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
If cuda graph is required, this API automatically pads inputs.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids, last_outputs_ids, last_output_sample)
seq_group_metadata_list, finished_requests_ids)
if get_pp_group().is_last_rank:
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
......
......@@ -209,9 +209,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
finished_requests_ids: Optional[List[str]] = None
) -> T:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
......
......@@ -374,9 +374,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids,
last_outputs_ids = execute_model_req.last_outputs_ids,
last_output_sample = execute_model_req.last_outputs_sample))
execute_model_req.finished_requests_ids))
if self.tree_decoding and execute_model_req.tree_position_ids is not None and \
execute_model_req.tree_attn_masks is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment