Commit ffa31925 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev' into v0.7.2-fusion

parents 9e813a0e a9267f52
...@@ -38,7 +38,7 @@ from transformers import PreTrainedTokenizerBase ...@@ -38,7 +38,7 @@ from transformers import PreTrainedTokenizerBase
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
import triton # import triton
try: try:
...@@ -46,23 +46,23 @@ try: ...@@ -46,23 +46,23 @@ try:
except ImportError: except ImportError:
from backend_request_func import get_tokenizer from backend_request_func import get_tokenizer
triton_version = triton.__version__ # triton_version = triton.__version__
if triton_version.startswith("2.1"): # if triton_version.startswith("2.1"):
from triton.common.backend import compute_core_version_key # from triton.common.backend import compute_core_version_key
elif triton_version.startswith("3.0"): # elif triton_version.startswith("3.0"):
from triton.compiler.compiler import triton_key # from triton.compiler.compiler import triton_key
else: # else:
print(f"TRITON version {triton_version} is not specifically handled.") # print(f"TRITON version {triton_version} is not specifically handled.")
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
def test_prefix(llm=None, sampling_params=None, prompts=None): def test_prefix(llm=None, sampling_params=None, prompts=None):
if triton_version.startswith("2.1"): # if triton_version.startswith("2.1"):
version_key = compute_core_version_key() # version_key = compute_core_version_key()
if triton_version.startswith("3.0"): # if triton_version.startswith("3.0"):
version_key = triton_key() # version_key = triton_key()
start_time = time.time() start_time = time.time()
llm.generate(prompts, sampling_params=sampling_params) llm.generate(prompts, sampling_params=sampling_params)
......
...@@ -488,11 +488,11 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -488,11 +488,11 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
version = 'das.opt2.alpha.' + sha[:7] version = 'das.opt1.' + sha[:7]
# version = 'das.opt1.' + sha[:7] # version = 'das.opt1.' + sha[:7]
else: else:
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
version = 'das.opt2.alpha' version = 'das.opt1'
# version = 'das.opt1' # version = 'das.opt1'
......
...@@ -5,6 +5,8 @@ from dataclasses import dataclass ...@@ -5,6 +5,8 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch import torch
import triton
from triton.compiler.compiler import triton_key
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -778,6 +780,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -778,6 +780,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
# prefix-enabled attention # prefix-enabled attention
# not applicable for encoder-only models # not applicable for encoder-only models
version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY: if self.attn_type != AttentionType.ENCODER_ONLY:
output[: output[:
num_prefill_tokens] = PagedAttention.forward_prefix( num_prefill_tokens] = PagedAttention.forward_prefix(
......
...@@ -43,7 +43,7 @@ from vllm.logits_process import get_bad_words_logits_processors ...@@ -43,7 +43,7 @@ from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor) get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SampleRecorder, SamplerOutput, get_last_sampler
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import (PoolingRequestOutput, RequestOutput, from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
...@@ -414,7 +414,6 @@ class LLMEngine: ...@@ -414,7 +414,6 @@ class LLMEngine:
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1' self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
if self.zero_overhead: if self.zero_overhead:
assert os.environ.get('HIP_ALLOC_INITIALIZE') == '0'
self.async_d2h = None self.async_d2h = None
self.last_record = None self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False) self.async_event = torch.cuda.Event(enable_timing=False)
...@@ -1246,12 +1245,13 @@ class LLMEngine: ...@@ -1246,12 +1245,13 @@ class LLMEngine:
def _fix_last_step( def _fix_last_step(
self, output: List[SamplerOutput], self, output: List[SamplerOutput],
last_sampler: SampleRecorder,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist() #sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
sample_out_list = self.async_d2h.tolist() sample_out_list = self.async_d2h.tolist()
sample_out_ids = output[0].sampler_out_ids.tolist() sample_out_ids = last_sampler.seq_ids
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
zip(seq_group_metadata_list, output[0], scheduled_seq_groups): zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
...@@ -1339,12 +1339,9 @@ class LLMEngine: ...@@ -1339,12 +1339,9 @@ class LLMEngine:
(seq_group_metadata_list, scheduler_outputs, (seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc allow_async_output_proc
) = self.scheduler[virtual_engine].schedule() ) = self.scheduler[virtual_engine].schedule()
last_outputs_ids = None
last_outputs_tensor = None
if self.last_record is not None: if self.last_record is not None:
last_output = self.last_record[0][0] last_sampler = self.last_record[1]
last_outputs_ids, last_outputs_tensor = last_output.sampler_out_ids, last_output.sampler_out_tenosr self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True)
self.async_event.record() self.async_event.record()
self.q_recorder.put(self.last_record) self.q_recorder.put(self.last_record)
else: else:
...@@ -1371,9 +1368,7 @@ class LLMEngine: ...@@ -1371,9 +1368,7 @@ class LLMEngine:
finished_requests_ids=finished_requests_ids, finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids # We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input. # to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids, last_sampled_token_ids=last_sampled_token_ids)
last_outputs_ids = last_outputs_ids,
last_outputs_sample = last_outputs_tensor)
outputs = self.model_executor.execute_model( outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
...@@ -1383,7 +1378,8 @@ class LLMEngine: ...@@ -1383,7 +1378,8 @@ class LLMEngine:
outputs[0], seq_group_metadata_list, outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups) scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] #deep copy scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] #deep copy
self.last_record = [outputs, seq_group_metadata_list, scheduler_outputs] last_sampler = get_last_sampler()
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs]
except Exception as e: except Exception as e:
print(f"thread_zero_overhead error : {e}") print(f"thread_zero_overhead error : {e}")
...@@ -1402,12 +1398,12 @@ class LLMEngine: ...@@ -1402,12 +1398,12 @@ class LLMEngine:
virtual_engine = 0 virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine] ctx = self.scheduler_contexts[virtual_engine]
ctx.request_outputs.clear() ctx.request_outputs.clear()
outputs, seq_group_metadata_list, scheduler_outputs = recode_output outputs, last_sampler, seq_group_metadata_list, scheduler_outputs = recode_output
ctx.seq_group_metadata_list = seq_group_metadata_list ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs ctx.scheduler_outputs = scheduler_outputs
self.async_event.synchronize() self.async_event.synchronize()
self._fix_last_step( self._fix_last_step(
outputs, seq_group_metadata_list, outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups) scheduler_outputs.scheduled_seq_groups)
# is_first_step_output is True only when the num_steps of all # is_first_step_output is True only when the num_steps of all
......
...@@ -1412,7 +1412,6 @@ class LLM: ...@@ -1412,7 +1412,6 @@ class LLM:
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
self.llm_engine.finish_thread()
# Sort the outputs by request ID. # Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than # This is necessary because some requests may be finished earlier than
# its previous requests. # its previous requests.
......
...@@ -70,15 +70,19 @@ class SampleResultArgsType: ...@@ -70,15 +70,19 @@ class SampleResultArgsType:
sampling_metadata: SamplingMetadata sampling_metadata: SamplingMetadata
greedy_samples: Optional[torch.Tensor] greedy_samples: Optional[torch.Tensor]
beam_search_logprobs: Optional[torch.Tensor] beam_search_logprobs: Optional[torch.Tensor]
# Implemented by guanyu
@dataclass
class SampleDeviceToDevices: class SampleRecorder:
def __init__(self): def __init__(self):
self.seq_id:torch.Tensor = None self.seq_ids:torch.Tensor = None
self.sampled_token_ids_tensor:torch.Tensor = None self.sampled_token_ids_tensor:torch.Tensor = None
self.zero_overhead:bool = False
d2d_data = SampleDeviceToDevices() last_sampler = None
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
def get_last_sampler():
return last_sampler
# Union of non-deferred (single-step scheduling) # Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling) # vs deferred (multi-step scheduling)
...@@ -214,8 +218,6 @@ class Sampler(nn.Module): ...@@ -214,8 +218,6 @@ class Sampler(nn.Module):
# speculative decoding. # speculative decoding.
self.include_gpu_probs_tensor = False self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False self.should_modify_greedy_probs_inplace = False
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
d2d_data.zero_overhead = self.zero_overhead
def _init_sampling_tensors( def _init_sampling_tensors(
self, self,
...@@ -266,6 +268,8 @@ class Sampler(nn.Module): ...@@ -266,6 +268,8 @@ class Sampler(nn.Module):
logits: (num_tokens, vocab_size). logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling. sampling_metadata: Metadata for sampling.
""" """
global last_sampler
last_sampler = SampleRecorder()
assert logits is not None assert logits is not None
_, vocab_size = logits.shape _, vocab_size = logits.shape
...@@ -476,7 +480,7 @@ def _greedy_sample( ...@@ -476,7 +480,7 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
if not d2d_data.zero_overhead: if not zero_overhead:
samples_lst = samples.tolist() samples_lst = samples.tolist()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
...@@ -490,7 +494,7 @@ def _greedy_sample( ...@@ -490,7 +494,7 @@ def _greedy_sample(
assert num_parent_seqs == 1, ( assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.") "Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs)) parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead: if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] #place holder token id next_token_ids = [0] #place holder token id
else: else:
...@@ -517,7 +521,7 @@ def _random_sample( ...@@ -517,7 +521,7 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
# Find the maximum n value of the prompt phase requests. # Find the maximum n value of the prompt phase requests.
if not d2d_data.zero_overhead: if not zero_overhead:
random_samples = random_samples.cpu() random_samples = random_samples.cpu()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
...@@ -533,7 +537,7 @@ def _random_sample( ...@@ -533,7 +537,7 @@ def _random_sample(
if is_prompt: if is_prompt:
# Prompt phase. # Prompt phase.
parent_ids = [0] * sampling_params.n parent_ids = [0] * sampling_params.n
if d2d_data.zero_overhead: if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * sampling_params.n #place holder token id next_token_ids = [0] * sampling_params.n #place holder token id
else: else:
...@@ -542,7 +546,7 @@ def _random_sample( ...@@ -542,7 +546,7 @@ def _random_sample(
else: else:
# Generation phase. # Generation phase.
parent_ids = list(range(num_parent_seqs)) parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead: if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * num_parent_seqs #place holder token id next_token_ids = [0] * num_parent_seqs #place holder token id
else: else:
...@@ -763,10 +767,10 @@ def _sample_with_torch( ...@@ -763,10 +767,10 @@ def _sample_with_torch(
t: [] t: []
for t in SamplingType for t in SamplingType
} }
d2d_data.seq_id = torch.zeros(len(sampling_metadata.seq_groups), dtype=torch.int32) last_sampler.seq_ids = []
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
d2d_data.seq_id[i] = seq_group.seq_ids[0] last_sampler.seq_ids.append(seq_group.seq_ids[0])
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
...@@ -801,8 +805,7 @@ def _sample_with_torch( ...@@ -801,8 +805,7 @@ def _sample_with_torch(
greedy_samples = torch.argmax(logprobs[long_sample_indices], greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1) dim=-1)
if d2d_data.zero_overhead: last_sampler.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
d2d_data.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
if sampled_token_ids_tensor is not None: if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor. # Store sampled tokens in output tensor.
...@@ -841,8 +844,7 @@ def _sample_with_torch( ...@@ -841,8 +844,7 @@ def _sample_with_torch(
max_n_in_batch, max_n_in_batch,
seq_groups=seq_groups_arg) seq_groups=seq_groups_arg)
if d2d_data.zero_overhead: last_sampler.sampled_token_ids_tensor = \
d2d_data.sampled_token_ids_tensor = \
multinomial_samples[sampling_type].to(torch.long) multinomial_samples[sampling_type].to(torch.long)
if sampled_token_ids_tensor is not None: if sampled_token_ids_tensor is not None:
...@@ -1308,9 +1310,7 @@ def _build_sampler_output( ...@@ -1308,9 +1310,7 @@ def _build_sampler_output(
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor, logprobs=logprobs_tensor,
deferred_sample_results_args=deferred_sample_results_args, deferred_sample_results_args=deferred_sample_results_args,
logits=logits, logits=logits)
sampler_out_tenosr = d2d_data.sampled_token_ids_tensor,
sampler_out_ids = d2d_data.seq_id)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
......
import torch
import triton
import triton.language as tl
@triton.jit
def _update_input_tokens(
sample_output,
seq_ids,
input_tokens,
input_seq_ids,
BATCH_SIZE1,
BATCH_SIZE2,
):
pid = tl.program_id(0)
if pid >= BATCH_SIZE2:
return
output_token = tl.load(input_tokens + pid)
_input_seq_id = tl.load(input_seq_ids + pid)
for i in range(BATCH_SIZE1):
_seq_ids = tl.load(seq_ids + i)
if _seq_ids == _input_seq_id:
output_token = tl.load(sample_output + i)
tl.store(input_tokens + pid, output_token)
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
grid = [input_seq_ids.shape[0], 1, 1]
_update_input_tokens[grid](last_sample, last_ids, input_tokens, input_seq_ids, last_ids.shape[0], input_seq_ids.shape[0])
\ No newline at end of file
...@@ -1465,12 +1465,6 @@ class ExecuteModelRequest( ...@@ -1465,12 +1465,6 @@ class ExecuteModelRequest(
# Optional slot mapping of kvcache that pending to be moved generated from draft model. # Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
# for zero-overhead scheduler
last_outputs_sample : Optional[torch.Tensor] = None
# for zero-overhead scheduler
last_outputs_ids : Optional[torch.Tensor] = None
@property @property
def is_first_multi_step(self) -> bool: def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of # TODO(will) make this be able to handle batches with variable number of
...@@ -1520,9 +1514,7 @@ class ExecuteModelRequest( ...@@ -1520,9 +1514,7 @@ class ExecuteModelRequest(
async_callback=self.async_callback, async_callback=self.async_callback,
tree_attn_masks=self.tree_attn_masks, tree_attn_masks=self.tree_attn_masks,
tree_position_ids=self.tree_position_ids, tree_position_ids=self.tree_position_ids,
kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved, kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved)
last_outputs_sample = self.last_outputs_sample,
last_outputs_ids = self.last_outputs_ids)
@dataclass @dataclass
......
...@@ -69,11 +69,6 @@ class AsyncMetricsCollector: ...@@ -69,11 +69,6 @@ class AsyncMetricsCollector:
self._in_flight_copy: Optional[torch.cuda.Event] = None self._in_flight_copy: Optional[torch.cuda.Event] = None
pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_draft_tokens = 0 self._aggregate_num_draft_tokens = 0
self._rejsample_metrics_collect_interval_s = collect_interval_s self._rejsample_metrics_collect_interval_s = collect_interval_s
...@@ -88,10 +83,17 @@ class AsyncMetricsCollector: ...@@ -88,10 +83,17 @@ class AsyncMetricsCollector:
device_type: Union[torch.device, str] = 'cuda') -> None: device_type: Union[torch.device, str] = 'cuda') -> None:
self._rank = rank self._rank = rank
if isinstance(device_type, torch.device): if isinstance(device_type, torch.device):
torch.cuda.set_device(device_type)
device_type = device_type.type device_type = device_type.type
if device_type == 'cuda': if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream() self._copy_stream = torch.cuda.Stream()
pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
def maybe_collect_rejsample_metrics( def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]: self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# currently using cuda.Event, skip for any non_cuda_alike platform # currently using cuda.Event, skip for any non_cuda_alike platform
......
...@@ -30,13 +30,11 @@ class TargetModelRunner(ModelRunnerWrapperBase): ...@@ -30,13 +30,11 @@ class TargetModelRunner(ModelRunnerWrapperBase):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None, finished_requests_ids: Optional[List[str]] = None
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
) -> ModelRunnerInputBase: ) -> ModelRunnerInputBase:
model_input: ModelRunnerInputBase =\ model_input: ModelRunnerInputBase =\
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
seq_group_metadata_list, virtual_engine, finished_requests_ids, last_outputs_ids, last_output_sample) seq_group_metadata_list, virtual_engine, finished_requests_ids)
# If token log probabilities is disabled then skip generating sampler # If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors # CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the # as needed. If log probabilities is enabled then synchronize all the
......
...@@ -36,7 +36,7 @@ from vllm.lora.request import LoRARequest ...@@ -36,7 +36,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_last_sampler
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models import supports_lora, supports_multimodal
...@@ -61,8 +61,6 @@ from vllm.worker.model_runner_base import ( ...@@ -61,8 +61,6 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict, _init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict) _init_sampling_metadata_from_tensor_dict)
from vllm.model_executor.layers.update_input import UpdateInputTokens
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
...@@ -479,14 +477,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -479,14 +477,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1' self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
self.last_sample_tensor = None
self.last_sample_ids = None
self.req_ids = [] self.req_ids = []
def SetLastSamperData(self, last_sample_ids, last_sample_tensor):
self.last_sample_tensor = last_sample_tensor
self.last_sample_ids = last_sample_ids
def prepare(self, def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None: finished_requests_ids: Optional[List[str]] = None) -> None:
self.finished_requests_ids = finished_requests_ids self.finished_requests_ids = finished_requests_ids
...@@ -915,14 +907,25 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -915,14 +907,25 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
if self.zero_overhead and self.last_sample_tensor is not None: if self.zero_overhead:
input_ids = async_tensor_h2d(self.req_ids, torch.long, last_sampler = get_last_sampler()
if last_sampler is not None:
update_indices = []
select_indices = []
for i, seq_id in enumerate(self.req_ids):
for j, seq_id_ in enumerate(last_sampler.seq_ids):
if seq_id == seq_id_:
select_indices.append(j)
update_indices.append(i)
break
if len(select_indices) > 0:
select_indices = async_tensor_h2d(select_indices, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
last_ids = async_tensor_h2d(self.last_sample_ids.tolist(), torch.long, update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
UpdateInputTokens(input_tokens_tensor, input_ids, self.last_sample_tensor, last_ids) input_tokens_tensor[update_indices] = last_sampler.sampled_token_ids_tensor[select_indices, 0]
token_types_tensor = async_tensor_h2d(token_types, torch.long, token_types_tensor = async_tensor_h2d(token_types, torch.long,
self.runner.device, self.runner.device,
...@@ -1225,9 +1228,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1225,9 +1228,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def _prepare_model_input_tensors( def _prepare_model_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None, finished_requests_ids: Optional[List[str]] = None
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
) -> TModelInputForGPU: ) -> TModelInputForGPU:
"""Helper method to prepare the model input based on a given sequence """Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not group. Prepares metadata needed for the base model forward pass but not
...@@ -1248,7 +1249,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1248,7 +1249,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.builder.add_seq_group(seq_group_metadata) self.builder.add_seq_group(seq_group_metadata)
self.builder.reset_cached_inter_data() self.builder.reset_cached_inter_data()
self.builder.SetLastSamperData(last_outputs_ids, last_output_sample)
return self.builder.build() # type: ignore return self.builder.build() # type: ignore
@contextmanager @contextmanager
...@@ -1642,9 +1642,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1642,9 +1642,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None, finished_requests_ids: Optional[List[str]] = None
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
) -> ModelInputForGPUWithSamplingMetadata: ) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including """Prepare the model input based on a given sequence group, including
metadata for the sampling step. metadata for the sampling step.
...@@ -1660,7 +1658,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1660,7 +1658,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
If cuda graph is required, this API automatically pads inputs. If cuda graph is required, this API automatically pads inputs.
""" """
model_input = self._prepare_model_input_tensors( model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids, last_outputs_ids, last_output_sample) seq_group_metadata_list, finished_requests_ids)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
# Sampling metadata is only required for the final pp group # Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids) generators = self.get_generators(finished_requests_ids)
......
...@@ -209,9 +209,7 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -209,9 +209,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None, finished_requests_ids: Optional[List[str]] = None
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
) -> T: ) -> T:
""" """
Prepare the inputs to ModelRunnerBase.execute_model from an execution Prepare the inputs to ModelRunnerBase.execute_model from an execution
......
...@@ -374,9 +374,7 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -374,9 +374,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine, execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids, execute_model_req.finished_requests_ids))
last_outputs_ids = execute_model_req.last_outputs_ids,
last_output_sample = execute_model_req.last_outputs_sample))
if self.tree_decoding and execute_model_req.tree_position_ids is not None and \ if self.tree_decoding and execute_model_req.tree_position_ids is not None and \
execute_model_req.tree_attn_masks is not None: execute_model_req.tree_attn_masks is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment