Commit 2344d22e authored by lizhigong's avatar lizhigong
Browse files

use two threads in step to improve first tokens

parent b78549c2
...@@ -14,8 +14,6 @@ from vllm.attention.backends.abstract import AttentionType ...@@ -14,8 +14,6 @@ from vllm.attention.backends.abstract import AttentionType
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.profiler.prof import profile
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.model_runner_base import ModelRunnerBase
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
import os import os
import copy import copy
import time import time
import threading
import queue
from collections import Counter as collectionsCounter from collections import Counter as collectionsCounter
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
...@@ -410,10 +412,17 @@ class LLMEngine: ...@@ -410,10 +412,17 @@ class LLMEngine:
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1' self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
self.step_switch = 0 # 0 step A 1 step B if self.zero_overhead:
self.output_recorder = [None, None] # self.step_switch = 0 # 0 step A 1 step B
self.async_d2h = None # self.output_recorder = [None, None]
self.async_event = torch.cuda.Event(enable_timing=False) self.async_d2h = None
self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False)
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.q_recorder = queue.Queue()
self.q_recorder.put(None) # None is use for first step ignore
self.sem_m2s = threading.Semaphore(0) # main to scheduler thread
self.zero_thread.start()
profile.StartTracer() profile.StartTracer()
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
...@@ -1307,6 +1316,129 @@ class LLMEngine: ...@@ -1307,6 +1316,129 @@ class LLMEngine:
def trans_last_output_tensor(self, last_output) -> torch.Tensor: def trans_last_output_tensor(self, last_output) -> torch.Tensor:
return None return None
def thread_zero_overhead(self):
    """Background loop: schedule and execute one model step per wake-up.

    This method is the target of ``self.zero_thread``. Each iteration
    waits on ``self.sem_m2s`` (released by ``zero_overhead_step``), runs
    the scheduler for virtual engine 0, hands the record produced by the
    *previous* iteration to ``self.q_recorder``, executes the model, and
    stores the new ``[outputs, seq_group_metadata_list, scheduler_outputs]``
    triple in ``self.last_record``.

    NOTE(review): the loop has no exit condition and no exception
    handling. If anything here raises, nothing is put on
    ``self.q_recorder`` and the blocking ``get()`` in
    ``zero_overhead_step`` would wait indefinitely -- confirm whether a
    shutdown/error path is needed.
    """
    while True:
        # Sleep until zero_overhead_step() requests another step.
        self.sem_m2s.acquire()
        # Only virtual engine 0 is handled on this path.
        virtual_engine = 0
        ctx = self.scheduler_contexts[virtual_engine]
        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()
        # Schedule iteration
        (seq_group_metadata_list, scheduler_outputs,
        allow_async_output_proc
        ) = self.scheduler[virtual_engine].schedule()
        ctx.seq_group_metadata_list = seq_group_metadata_list
        ctx.scheduler_outputs = scheduler_outputs
        finished_requests_ids = self.scheduler[
        virtual_engine].get_and_reset_finished_requests_ids()
        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None
        last_outputs_ids = None
        last_outputs_tensor = None
        # On every iteration after the first, publish the previous step.
        if self.last_record is not None:
            # First element of the previous step's outputs list.
            last_output = self.last_record[0][0]
            last_outputs_ids, last_outputs_tensor = last_output.sampler_out_ids, last_output.sampler_out_tenosr
            # Start a non-blocking copy of the previous sampler tensor to
            # the CPU and record an event after it; zero_overhead_step()
            # synchronizes on self.async_event before using the record.
            self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True)
            self.async_event.record()
            # Hand the previous record to the main thread, which is
            # blocked in q_recorder.get().
            self.q_recorder.put(self.last_record)
        last_sampled_token_ids = \
        self._get_last_sampled_token_ids(virtual_engine)
        execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
        blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
        blocks_to_copy=scheduler_outputs.blocks_to_copy,
        num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
        running_queue_size=scheduler_outputs.running_queue_size,
        finished_requests_ids=finished_requests_ids,
        # We use ExecuteModelRequest to pass the last sampled_token_ids
        # to each of the non-last PP stages for in-place prepare_input.
        last_sampled_token_ids=last_sampled_token_ids,
        # Previous step's sampler ids/tensor (None on the first iteration).
        last_outputs_ids = last_outputs_ids,
        last_outputs_sample = last_outputs_tensor)
        if allow_async_output_proc:
            execute_model_req.async_callback = self.async_callbacks[
            virtual_engine]
        #profile.ProfRangeAutoPush('model_executor')
        # NOTE(review): unlike a guarded path, the model is executed and
        # outputs[0] is indexed without checking scheduler_outputs.is_empty()
        # -- confirm an empty schedule cannot reach this point.
        outputs = self.model_executor.execute_model(
        execute_model_req=execute_model_req)
        # Advance the sequences immediately, before the main thread has
        # processed these outputs; zero_overhead_step() later calls
        # _fix_last_step() on the same record (presumably to patch in the
        # real token ids once the copy has completed -- TODO confirm).
        self._advance_to_next_step(
        outputs[0], seq_group_metadata_list,
        scheduler_outputs.scheduled_seq_groups)
        # Keep this step's results; they are published on the next iteration.
        self.last_record = [outputs, seq_group_metadata_list, scheduler_outputs]
def zero_overhead_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
    """Run one engine step using the background scheduling thread.

    Releases ``self.sem_m2s`` so ``thread_zero_overhead`` starts another
    schedule/execute iteration, then blocks on ``self.q_recorder`` for a
    record of an earlier step and post-processes it.

    Returns:
        ``ctx.request_outputs`` for virtual engine 0, or ``None`` when the
        item taken from the queue is the ``None`` placeholder.
        NOTE(review): the ``None`` return is not reflected in the
        annotated return type; callers must handle it.
    """
    # Wake the worker thread so it can schedule and launch the next step.
    self.sem_m2s.release()
    # Blocks until a record (or the placeholder) is available.
    recode_output = self.q_recorder.get()
    if recode_output is None: # None is for the first step
        return None
    virtual_engine = 0
    ctx = self.scheduler_contexts[virtual_engine]
    outputs, seq_group_metadata_list, scheduler_outputs = recode_output
    # Point the context at the step being post-processed.
    # NOTE(review): thread_zero_overhead() also assigns these two ctx
    # fields from its own thread -- confirm the ordering is safe.
    ctx.seq_group_metadata_list = seq_group_metadata_list
    ctx.scheduler_outputs = scheduler_outputs
    # Wait for the event recorded after the worker's non-blocking copy
    # of the sampler tensor to the CPU.
    self.async_event.synchronize()
    self._fix_last_step(
    outputs, seq_group_metadata_list,
    scheduler_outputs.scheduled_seq_groups)
    # Hard-coded to True here, so the "not allow_async_output_proc"
    # branch below can never run on this path.
    allow_async_output_proc = True
    if not self._has_remaining_steps(seq_group_metadata_list):
        # clear the cache if we have finished all the steps.
        if self.scheduler_config.is_multi_step:
            self.cached_scheduler_outputs[0] = SchedulerOutputState()
        # is_first_step_output is True only when the num_steps of all
        # the sequences are 1. When the num_steps > 1,
        # multi_step_model_runner does the first-step output append.
        is_first_step_output: bool = False if not seq_group_metadata_list \
        else seq_group_metadata_list[0].state.num_steps == 1
        # Add results to the output_queue
        ctx.append_output(outputs=outputs,
        seq_group_metadata_list=seq_group_metadata_list,
        scheduler_outputs=scheduler_outputs,
        is_async=allow_async_output_proc,
        is_last_step=True,
        is_first_step_output=is_first_step_output)
        # Check if need to run the usual non-async path
        if not allow_async_output_proc:
            self._process_model_outputs(ctx=ctx)
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs)
            # Tracing
            self.do_tracing(scheduler_outputs)
    else:
        # Multi-step case
        return ctx.request_outputs
    #profile.ProfRangeAutoPush('has_unfinish')
    if not self.has_unfinished_requests():
        # Drain async postprocessor (if exists)
        if len(ctx.output_queue) > 0:
            self._process_model_outputs(ctx=ctx)
        assert len(ctx.output_queue) == 0
        # Stop the execute model loop in parallel workers until there are
        # more requests to process. This avoids waiting indefinitely in
        # torch.distributed ops which may otherwise timeout, and unblocks
        # the RPC thread in the workers so that they can process any other
        # queued control plane messages, such as add/remove lora adapters.
        logger.debug("Stopping remote worker execution loop.")
        self.model_executor.stop_remote_worker_execution_loop()
    return ctx.request_outputs
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
...@@ -1359,6 +1491,9 @@ class LLMEngine: ...@@ -1359,6 +1491,9 @@ class LLMEngine:
>>> if not (engine.has_unfinished_requests() or example_inputs): >>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break >>> break
""" """
if self.zero_overhead:
return self.zero_overhead_step()
if self.parallel_config.pipeline_parallel_size > 1: if self.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError( raise NotImplementedError(
"Pipeline parallelism is only supported through AsyncLLMEngine " "Pipeline parallelism is only supported through AsyncLLMEngine "
...@@ -1383,7 +1518,6 @@ class LLMEngine: ...@@ -1383,7 +1518,6 @@ class LLMEngine:
# Skip the scheduler if there are any remaining steps in the seq groups. # Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current # This ensures that the scheduler is only called again when the current
# batch has completed. # batch has completed.
profile.ProfRangeAutoPush('has_remain')
if not self._has_remaining_steps(seq_group_metadata_list): if not self._has_remaining_steps(seq_group_metadata_list):
# Schedule iteration # Schedule iteration
(seq_group_metadata_list, scheduler_outputs, (seq_group_metadata_list, scheduler_outputs,
...@@ -1413,15 +1547,6 @@ class LLMEngine: ...@@ -1413,15 +1547,6 @@ class LLMEngine:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
assert scheduler_outputs is not None assert scheduler_outputs is not None
last_outputs_ids = None
last_outputs_tensor = None
if self.zero_overhead:
recode_output = self.output_recorder[1 - self.step_switch]
if recode_output is not None:
last_output = recode_output[0][0]
last_outputs_ids, last_outputs_tensor = last_output.sampler_out_ids, last_output.sampler_out_tenosr
self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True)
self.async_event.record()
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
# Check if we have a cached last_output from the previous iteration. # Check if we have a cached last_output from the previous iteration.
...@@ -1441,17 +1566,15 @@ class LLMEngine: ...@@ -1441,17 +1566,15 @@ class LLMEngine:
finished_requests_ids=finished_requests_ids, finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids # We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input. # to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids, last_sampled_token_ids=last_sampled_token_ids)
last_outputs_ids = last_outputs_ids,
last_outputs_sample = last_outputs_tensor)
if allow_async_output_proc: if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[ execute_model_req.async_callback = self.async_callbacks[
virtual_engine] virtual_engine]
profile.ProfRangeAutoPush('model_executor') #profile.ProfRangeAutoPush('model_executor')
outputs = self.model_executor.execute_model( outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
profile.ProfRangeAutoPush('end_executor') #profile.ProfRangeAutoPush('end_executor')
# We need to do this here so that last step's sampled_token_ids can # We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP. # be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
...@@ -1464,26 +1587,6 @@ class LLMEngine: ...@@ -1464,26 +1587,6 @@ class LLMEngine:
# No outputs in this case # No outputs in this case
outputs = [] outputs = []
if self.zero_overhead:
self.output_recorder[self.step_switch] = [outputs, seq_group_metadata_list, scheduler_outputs]
self._advance_to_next_step(
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
self.step_switch = 1 - self.step_switch
recode_output = self.output_recorder[self.step_switch]
if recode_output is None:
return None
outputs, seq_group_metadata_list, scheduler_outputs = self.output_recorder[self.step_switch]
self.output_recorder[self.step_switch] = None # only use for once
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
self.async_event.synchronize()
self._fix_last_step(
outputs, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# Finish the current step for all the sequence groups. # Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list: for seq_group in seq_group_metadata_list:
...@@ -1511,10 +1614,9 @@ class LLMEngine: ...@@ -1511,10 +1614,9 @@ class LLMEngine:
if outputs and allow_async_output_proc: if outputs and allow_async_output_proc:
assert len(outputs) == 1, ( assert len(outputs) == 1, (
"Async postprocessor expects only a single output set") "Async postprocessor expects only a single output set")
if not self.zero_overhead: self._advance_to_next_step(
self._advance_to_next_step( outputs[0], seq_group_metadata_list,
outputs[0], seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups)
# Check if need to run the usual non-async path # Check if need to run the usual non-async path
if not allow_async_output_proc: if not allow_async_output_proc:
...@@ -1529,7 +1631,7 @@ class LLMEngine: ...@@ -1529,7 +1631,7 @@ class LLMEngine:
# Multi-step case # Multi-step case
return ctx.request_outputs return ctx.request_outputs
profile.ProfRangeAutoPush('has_unfinish') #profile.ProfRangeAutoPush('has_unfinish')
if not self.has_unfinished_requests(): if not self.has_unfinished_requests():
# Drain async postprocessor (if exists) # Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0: if len(ctx.output_queue) > 0:
......
...@@ -1388,7 +1388,6 @@ class LLM: ...@@ -1388,7 +1388,6 @@ class LLM:
total_out_toks = 0 total_out_toks = 0
while self.llm_engine.has_unfinished_requests(): while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step() step_outputs = self.llm_engine.step()
#print('###step_outputs', step_outputs)
if step_outputs is None: if step_outputs is None:
continue continue
for output in step_outputs: for output in step_outputs:
......
...@@ -32,7 +32,6 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): ...@@ -32,7 +32,6 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
# yapf: enable # yapf: enable
else: else:
flashinfer_top_k_top_p_sampling = None flashinfer_top_k_top_p_sampling = None
from vllm.profiler.prof import profile
def get_sampler() -> torch.nn.Module: def get_sampler() -> torch.nn.Module:
...@@ -267,7 +266,6 @@ class Sampler(nn.Module): ...@@ -267,7 +266,6 @@ class Sampler(nn.Module):
logits: (num_tokens, vocab_size). logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling. sampling_metadata: Metadata for sampling.
""" """
profile.ProfRangeAutoPush('sampler_forward')
assert logits is not None assert logits is not None
_, vocab_size = logits.shape _, vocab_size = logits.shape
...@@ -280,7 +278,6 @@ class Sampler(nn.Module): ...@@ -280,7 +278,6 @@ class Sampler(nn.Module):
# reuse sampling tensors, since "output_tokens" changes # reuse sampling tensors, since "output_tokens" changes
# between decode runs. # between decode runs.
self._init_sampling_tensors(logits, sampling_metadata) self._init_sampling_tensors(logits, sampling_metadata)
profile.ProfRangeAutoPush('sampler1')
assert self._sampling_tensors is not None assert self._sampling_tensors is not None
sampling_tensors = self._sampling_tensors sampling_tensors = self._sampling_tensors
......
...@@ -11,7 +11,6 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, ...@@ -11,7 +11,6 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.utils import (PyObjectCache, async_tensor_h2d, from vllm.utils import (PyObjectCache, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad) is_pin_memory_available, make_tensor_with_pad)
from vllm.profiler.prof import profile
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
...@@ -512,7 +511,6 @@ class SamplingTensors: ...@@ -512,7 +511,6 @@ class SamplingTensors:
) -> "SamplingTensors": ) -> "SamplingTensors":
# Note that the performance will be very bad without # Note that the performance will be very bad without
# pinned memory. # pinned memory.
profile.ProfRangeAutoPush('from_lists')
pin_memory = is_pin_memory_available() pin_memory = is_pin_memory_available()
do_penalties = prompt_tokens or output_tokens do_penalties = prompt_tokens or output_tokens
...@@ -535,7 +533,6 @@ class SamplingTensors: ...@@ -535,7 +533,6 @@ class SamplingTensors:
empty_tensor = torch.empty(0, device=device, dtype=torch.long) empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_t = empty_tensor prompt_t = empty_tensor
output_t = empty_tensor output_t = empty_tensor
profile.ProfRangeAutoPush('from_lists1')
temperatures_t = torch.tensor( temperatures_t = torch.tensor(
temperatures, temperatures,
device="cpu", device="cpu",
...@@ -581,7 +578,6 @@ class SamplingTensors: ...@@ -581,7 +578,6 @@ class SamplingTensors:
# Because the memory is pinned, we can do non-blocking # Because the memory is pinned, we can do non-blocking
# transfer to device. # transfer to device.
profile.ProfRangeAutoPush('from_lists2')
return cls( return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True), temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True),
......
...@@ -583,9 +583,9 @@ class Sequence: ...@@ -583,9 +583,9 @@ class Sequence:
self.data.append_token_id(token_id, logprobs[token_id].logprob) self.data.append_token_id(token_id, logprobs[token_id].logprob)
def fix_last_token_id(self, token_id: int) -> None: def fix_last_token_id(self, token_id: int) -> None:
self.data._output_token_ids[-2] = token_id self.data._output_token_ids[-1] = token_id
self.data._new_appended_tokens[-2] = token_id self.data._new_appended_tokens[-1] = token_id
self.data._cached_all_token_ids[-2] = token_id self.data._cached_all_token_ids[-1] = token_id
def get_len(self) -> int: def get_len(self) -> int:
return self.data.get_len() return self.data.get_len()
......
...@@ -62,8 +62,6 @@ from vllm.worker.model_runner_base import ( ...@@ -62,8 +62,6 @@ from vllm.worker.model_runner_base import (
from vllm.model_executor.layers.ops.update_input import UpdateInputTokens from vllm.model_executor.layers.ops.update_input import UpdateInputTokens
from vllm.profiler.prof import profile
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
...@@ -841,7 +839,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -841,7 +839,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Combine and flatten intermediate data. # Combine and flatten intermediate data.
input_tokens = [] input_tokens = []
token_types = [] token_types = []
profile.ProfRangeAutoPush('build')
for inter_data in self.inter_data_list: for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens: for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens) input_tokens.extend(cur_input_tokens)
...@@ -1006,7 +1003,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -1006,7 +1003,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
] ]
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
profile.ProfRangeAutoPush('build_end')
return self.model_input_cls( return self.model_input_cls(
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor, input_positions=input_positions_tensor,
...@@ -1700,7 +1696,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1700,7 +1696,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.set_active_prompt_adapters( self.set_active_prompt_adapters(
model_input.prompt_adapter_requests, model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping) model_input.prompt_adapter_mapping)
profile.ProfRangeAutoPush('begin_forward')
self.attn_state.begin_forward(model_input) self.attn_state.begin_forward(model_input)
# Currently cuda graph is only supported by the decode phase. # Currently cuda graph is only supported by the decode phase.
...@@ -1802,7 +1797,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1802,7 +1797,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
torch.tensor(model_forward_time + orig_model_forward_time)) torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states return hidden_or_intermediate_states
profile.ProfRangeAutoPush('compute_logits')
logits = self.model.compute_logits(hidden_or_intermediate_states, logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata) model_input.sampling_metadata)
...@@ -1813,12 +1807,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1813,12 +1807,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_input.async_callback() model_input.async_callback()
# Sample the next token. # Sample the next token.
profile.ProfRangeAutoPush('sample')
output: SamplerOutput = self.model.sample( output: SamplerOutput = self.model.sample(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
profile.ProfRangeAutoPush('sample_end')
if (self.observability_config is not None if (self.observability_config is not None
and self.observability_config.collect_model_forward_time and self.observability_config.collect_model_forward_time
and output is not None): and output is not None):
...@@ -1836,7 +1828,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1836,7 +1828,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
output.model_forward_time = (orig_model_forward_time + output.model_forward_time = (orig_model_forward_time +
model_forward_time) model_forward_time)
profile.ProfRangeAutoPush('output')
if self.return_hidden_states: if self.return_hidden_states:
# we only need to pass hidden states of most recent token # we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None assert model_input.sampling_metadata is not None
......
...@@ -25,7 +25,6 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput, ...@@ -25,7 +25,6 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase, ModelRunnerBase,
ModelRunnerInputBase) ModelRunnerInputBase)
from vllm.profiler.prof import profile
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -447,7 +446,6 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -447,7 +446,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
and self.observability_config.collect_model_execute_time): and self.observability_config.collect_model_execute_time):
orig_model_execute_time = intermediate_tensors.tensors.get( orig_model_execute_time = intermediate_tensors.tensors.get(
"model_execute_time", torch.tensor(0)).item() "model_execute_time", torch.tensor(0)).item()
profile.ProfRangeAutoPush('execute')
output = self.model_runner.execute_model( output = self.model_runner.execute_model(
model_input=model_input, model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine] kv_caches=self.kv_cache[worker_input.virtual_engine]
...@@ -456,7 +454,6 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -456,7 +454,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
num_steps=num_steps, num_steps=num_steps,
**kwargs, **kwargs,
) )
profile.ProfRangeAutoPush('output')
model_execute_time = time.perf_counter() - start_time model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment