Commit 3812059e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'origin/v0.7.2-dev-zero-overhead' into v0.7.2-dev

parents 07bcd2d4 333e3374
......@@ -43,7 +43,7 @@ from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import SampleRecorder, SamplerOutput, get_last_sampler
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
......@@ -414,7 +414,6 @@ class LLMEngine:
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
if self.zero_overhead:
assert os.environ.get('HIP_ALLOC_INITIALIZE') == '0'
self.async_d2h = None
self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False)
......@@ -1246,12 +1245,13 @@ class LLMEngine:
def _fix_last_step(
self, output: List[SamplerOutput],
last_sampler: SampleRecorder,
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
sample_out_list = self.async_d2h.tolist()
sample_out_ids = output[0].sampler_out_ids.tolist()
sample_out_ids = last_sampler.seq_ids
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
......@@ -1339,12 +1339,9 @@ class LLMEngine:
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
last_outputs_ids = None
last_outputs_tensor = None
if self.last_record is not None:
last_output = self.last_record[0][0]
last_outputs_ids, last_outputs_tensor = last_output.sampler_out_ids, last_output.sampler_out_tenosr
self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True)
last_sampler = self.last_record[1]
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
self.async_event.record()
self.q_recorder.put(self.last_record)
else:
......@@ -1371,9 +1368,7 @@ class LLMEngine:
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids,
last_outputs_ids = last_outputs_ids,
last_outputs_sample = last_outputs_tensor)
last_sampled_token_ids=last_sampled_token_ids)
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
......@@ -1383,7 +1378,8 @@ class LLMEngine:
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] #deep copy
self.last_record = [outputs, seq_group_metadata_list, scheduler_outputs]
last_sampler = get_last_sampler()
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs]
except Exception as e:
print(f"thread_zero_overhead error : {e}")
......@@ -1402,12 +1398,12 @@ class LLMEngine:
virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine]
ctx.request_outputs.clear()
outputs, seq_group_metadata_list, scheduler_outputs = recode_output
outputs, last_sampler, seq_group_metadata_list, scheduler_outputs = recode_output
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
self.async_event.synchronize()
self._fix_last_step(
outputs, seq_group_metadata_list,
outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# is_first_step_output is True only when the num_steps of all
......
......@@ -1412,7 +1412,6 @@ class LLM:
if use_tqdm:
pbar.close()
self.llm_engine.finish_thread()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
......
......@@ -70,15 +70,19 @@ class SampleResultArgsType:
sampling_metadata: SamplingMetadata
greedy_samples: Optional[torch.Tensor]
beam_search_logprobs: Optional[torch.Tensor]
# Implemented by guanyu
@dataclass
class SampleDeviceToDevices:
class SampleRecorder:
def __init__(self):
self.seq_id:torch.Tensor = None
self.seq_ids:torch.Tensor = None
self.sampled_token_ids_tensor:torch.Tensor = None
self.zero_overhead:bool = False
d2d_data = SampleDeviceToDevices()
last_sampler = None
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
def get_last_sampler():
return last_sampler
# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
......@@ -214,8 +218,6 @@ class Sampler(nn.Module):
# speculative decoding.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
d2d_data.zero_overhead = self.zero_overhead
def _init_sampling_tensors(
self,
......@@ -266,6 +268,8 @@ class Sampler(nn.Module):
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
global last_sampler
last_sampler = SampleRecorder()
assert logits is not None
_, vocab_size = logits.shape
......@@ -476,7 +480,7 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
if not d2d_data.zero_overhead:
if not zero_overhead:
samples_lst = samples.tolist()
sample_idx = 0
results: SampleResultType = []
......@@ -490,7 +494,7 @@ def _greedy_sample(
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead:
if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] #place holder token id
else:
......@@ -517,7 +521,7 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
if not d2d_data.zero_overhead:
if not zero_overhead:
random_samples = random_samples.cpu()
sample_idx = 0
results: SampleResultType = []
......@@ -533,7 +537,7 @@ def _random_sample(
if is_prompt:
# Prompt phase.
parent_ids = [0] * sampling_params.n
if d2d_data.zero_overhead:
if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * sampling_params.n #place holder token id
else:
......@@ -542,7 +546,7 @@ def _random_sample(
else:
# Generation phase.
parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead:
if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * num_parent_seqs #place holder token id
else:
......@@ -763,10 +767,10 @@ def _sample_with_torch(
t: []
for t in SamplingType
}
d2d_data.seq_id = torch.zeros(len(sampling_metadata.seq_groups), dtype=torch.int32)
last_sampler.seq_ids = []
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
d2d_data.seq_id[i] = seq_group.seq_ids[0]
last_sampler.seq_ids.append(seq_group.seq_ids[0])
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
......@@ -801,8 +805,7 @@ def _sample_with_torch(
greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)
if d2d_data.zero_overhead:
d2d_data.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
last_sampler.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
......@@ -841,8 +844,7 @@ def _sample_with_torch(
max_n_in_batch,
seq_groups=seq_groups_arg)
if d2d_data.zero_overhead:
d2d_data.sampled_token_ids_tensor = \
last_sampler.sampled_token_ids_tensor = \
multinomial_samples[sampling_type].to(torch.long)
if sampled_token_ids_tensor is not None:
......@@ -1308,9 +1310,7 @@ def _build_sampler_output(
sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor,
deferred_sample_results_args=deferred_sample_results_args,
logits=logits,
sampler_out_tenosr = d2d_data.sampled_token_ids_tensor,
sampler_out_ids = d2d_data.seq_id)
logits=logits)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
......
import torch
import triton
import triton.language as tl
@triton.jit
def _update_input_tokens(
sample_output,
seq_ids,
input_tokens,
input_seq_ids,
BATCH_SIZE1,
BATCH_SIZE2,
):
pid = tl.program_id(0)
if pid >= BATCH_SIZE2:
return
output_token = tl.load(input_tokens + pid)
_input_seq_id = tl.load(input_seq_ids + pid)
for i in range(BATCH_SIZE1):
_seq_ids = tl.load(seq_ids + i)
if _seq_ids == _input_seq_id:
output_token = tl.load(sample_output + i)
tl.store(input_tokens + pid, output_token)
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
grid = [input_seq_ids.shape[0], 1, 1]
_update_input_tokens[grid](last_sample, last_ids, input_tokens, input_seq_ids, last_ids.shape[0], input_seq_ids.shape[0])
\ No newline at end of file
......@@ -1465,12 +1465,6 @@ class ExecuteModelRequest(
# Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
# for zero-overhead scheduler
last_outputs_sample : Optional[torch.Tensor] = None
# for zero-overhead scheduler
last_outputs_ids : Optional[torch.Tensor] = None
@property
def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of
......@@ -1520,9 +1514,7 @@ class ExecuteModelRequest(
async_callback=self.async_callback,
tree_attn_masks=self.tree_attn_masks,
tree_position_ids=self.tree_position_ids,
kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved,
last_outputs_sample = self.last_outputs_sample,
last_outputs_ids = self.last_outputs_ids)
kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved)
@dataclass
......
......@@ -30,13 +30,11 @@ class TargetModelRunner(ModelRunnerWrapperBase):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
finished_requests_ids: Optional[List[str]] = None
) -> ModelRunnerInputBase:
model_input: ModelRunnerInputBase =\
self.model_runner.prepare_model_input(
seq_group_metadata_list, virtual_engine, finished_requests_ids, last_outputs_ids, last_output_sample)
seq_group_metadata_list, virtual_engine, finished_requests_ids)
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the
......
......@@ -36,7 +36,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_last_sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal
......@@ -61,8 +61,6 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from vllm.model_executor.layers.update_input import UpdateInputTokens
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
......@@ -479,14 +477,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
self.last_sample_tensor = None
self.last_sample_ids = None
self.req_ids = []
def SetLastSamperData(self, last_sample_ids, last_sample_tensor):
self.last_sample_tensor = last_sample_tensor
self.last_sample_ids = last_sample_ids
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.finished_requests_ids = finished_requests_ids
......@@ -915,14 +907,25 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.runner.device,
self.runner.pin_memory)
if self.zero_overhead and self.last_sample_tensor is not None:
input_ids = async_tensor_h2d(self.req_ids, torch.long,
if self.zero_overhead:
last_sampler = get_last_sampler()
if last_sampler is not None:
update_indices = []
select_indices = []
for i, seq_id in enumerate(self.req_ids):
for j, seq_id_ in enumerate(last_sampler.seq_ids):
if seq_id == seq_id_:
select_indices.append(j)
update_indices.append(i)
break
if len(select_indices) > 0:
select_indices = async_tensor_h2d(select_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
last_ids = async_tensor_h2d(self.last_sample_ids.tolist(), torch.long,
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
UpdateInputTokens(input_tokens_tensor, input_ids, self.last_sample_tensor, last_ids)
input_tokens_tensor[update_indices] = last_sampler.sampled_token_ids_tensor[select_indices, 0]
token_types_tensor = async_tensor_h2d(token_types, torch.long,
self.runner.device,
......@@ -1225,9 +1228,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
finished_requests_ids: Optional[List[str]] = None
) -> TModelInputForGPU:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
......@@ -1248,7 +1249,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.builder.add_seq_group(seq_group_metadata)
self.builder.reset_cached_inter_data()
self.builder.SetLastSamperData(last_outputs_ids, last_output_sample)
return self.builder.build() # type: ignore
@contextmanager
......@@ -1642,9 +1642,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
......@@ -1660,7 +1658,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
If cuda graph is required, this API automatically pads inputs.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids, last_outputs_ids, last_output_sample)
seq_group_metadata_list, finished_requests_ids)
if get_pp_group().is_last_rank:
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
......
......@@ -209,9 +209,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
finished_requests_ids: Optional[List[str]] = None
) -> T:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
......
......@@ -374,9 +374,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids,
last_outputs_ids = execute_model_req.last_outputs_ids,
last_output_sample = execute_model_req.last_outputs_sample))
execute_model_req.finished_requests_ids))
if self.tree_decoding and execute_model_req.tree_position_ids is not None and \
execute_model_req.tree_attn_masks is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment