Commit 4851c202 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.1' into v0.6.1-dev

parents 9b902f9e 3fd2b0d2
......@@ -13,7 +13,6 @@ except ModuleNotFoundError:
import torch
from vllm import _custom_ops as ops
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
......@@ -274,12 +273,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self.pinned_sampled_token_ids)
if model_output.pythonized:
ctx = output_proc_callback.keywords["ctx"]
is_async = False
is_last_step = False
ctx.output_queue.append(
([model_output.sampler_output
], ctx.seq_group_metadata_list,
ctx.scheduler_outputs, is_async, is_last_step))
ctx.append_output(
outputs=[model_output.sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False)
output_proc_callback()
else:
cont = False
......@@ -319,12 +319,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
if not is_last_step:
ctx = output_proc_callback.keywords[ # type: ignore
"ctx"] # type: ignore
is_async = False
is_last_step = False
ctx.output_queue.append(
([output.sampler_output
], ctx.seq_group_metadata_list,
ctx.scheduler_outputs, is_async, is_last_step))
ctx.append_output(
outputs=[output.sampler_output],
seq_group_metadata_list=ctx.
seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False)
else:
outputs.append(output.sampler_output)
else:
......@@ -497,19 +498,11 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
attn_metadata = frozen_model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
attn_metadata.advance_step(num_seqs, num_queries)
# Update GPU tensors
ops.advance_step(
num_seqs=num_seqs,
num_queries=num_queries,
block_size=self.block_size,
input_tokens=frozen_model_input.input_tokens,
sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids,
input_positions=frozen_model_input.input_positions,
seq_lens=attn_metadata.seq_lens_tensor,
slot_mapping=attn_metadata.slot_mapping,
block_tables=attn_metadata.block_tables)
attn_metadata.advance_step(
frozen_model_input,
model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
num_seqs, num_queries)
if frozen_model_input.seq_lens is not None:
for i in range(num_queries):
......
......@@ -11,7 +11,7 @@ import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger
......@@ -611,7 +611,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
return [SamplerOutput(sampler_outputs)]
class ModelWrapper(TorchCompileWrapperWithCustomDispacther):
class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
def __init__(self, model: nn.Module):
self.model = model
......
......@@ -166,6 +166,7 @@ class Worker(LocalOrDistributedWorkerBase):
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment