Commit 4851c202 authored by zhuwenwen
Browse files

Merge tag 'v0.6.1' into v0.6.1-dev

parents 9b902f9e 3fd2b0d2
...@@ -13,7 +13,6 @@ except ModuleNotFoundError: ...@@ -13,7 +13,6 @@ except ModuleNotFoundError:
import torch import torch
from vllm import _custom_ops as ops
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
...@@ -274,12 +273,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -274,12 +273,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self.pinned_sampled_token_ids) self.pinned_sampled_token_ids)
if model_output.pythonized: if model_output.pythonized:
ctx = output_proc_callback.keywords["ctx"] ctx = output_proc_callback.keywords["ctx"]
is_async = False ctx.append_output(
is_last_step = False outputs=[model_output.sampler_output],
ctx.output_queue.append( seq_group_metadata_list=ctx.seq_group_metadata_list,
([model_output.sampler_output scheduler_outputs=ctx.scheduler_outputs,
], ctx.seq_group_metadata_list, is_async=False,
ctx.scheduler_outputs, is_async, is_last_step)) is_last_step=False)
output_proc_callback() output_proc_callback()
else: else:
cont = False cont = False
...@@ -319,12 +319,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -319,12 +319,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
if not is_last_step: if not is_last_step:
ctx = output_proc_callback.keywords[ # type: ignore ctx = output_proc_callback.keywords[ # type: ignore
"ctx"] # type: ignore "ctx"] # type: ignore
is_async = False ctx.append_output(
is_last_step = False outputs=[output.sampler_output],
ctx.output_queue.append( seq_group_metadata_list=ctx.
([output.sampler_output seq_group_metadata_list,
], ctx.seq_group_metadata_list, scheduler_outputs=ctx.scheduler_outputs,
ctx.scheduler_outputs, is_async, is_last_step)) is_async=False,
is_last_step=False)
else: else:
outputs.append(output.sampler_output) outputs.append(output.sampler_output)
else: else:
...@@ -497,19 +498,11 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -497,19 +498,11 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
attn_metadata = frozen_model_input.attn_metadata attn_metadata = frozen_model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata) assert isinstance(attn_metadata, FlashAttentionMetadata)
attn_metadata.advance_step(num_seqs, num_queries)
attn_metadata.advance_step(
# Update GPU tensors frozen_model_input,
ops.advance_step( model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
num_seqs=num_seqs, num_seqs, num_queries)
num_queries=num_queries,
block_size=self.block_size,
input_tokens=frozen_model_input.input_tokens,
sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids,
input_positions=frozen_model_input.input_positions,
seq_lens=attn_metadata.seq_lens_tensor,
slot_mapping=attn_metadata.slot_mapping,
block_tables=attn_metadata.block_tables)
if frozen_model_input.seq_lens is not None: if frozen_model_input.seq_lens is not None:
for i in range(num_queries): for i in range(num_queries):
......
...@@ -11,7 +11,7 @@ import torch_xla.core.xla_model as xm ...@@ -11,7 +11,7 @@ import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -611,7 +611,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -611,7 +611,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
return [SamplerOutput(sampler_outputs)] return [SamplerOutput(sampler_outputs)]
class ModelWrapper(TorchCompileWrapperWithCustomDispacther): class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
def __init__(self, model: nn.Module): def __init__(self, model: nn.Module):
self.model = model self.model = model
......
...@@ -166,6 +166,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -166,6 +166,7 @@ class Worker(LocalOrDistributedWorkerBase):
torch.cuda.set_device(self.device) torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype) _check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0] self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else: else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment