Unverified Commit d58268c5 authored by Joe Runde's avatar Joe Runde Committed by GitHub
Browse files

[V1] Make v1 more testable (#9888)


Signed-off-by: default avatarJoe Runde <Joseph.Runde@ibm.com>
parent 87bd7e05
...@@ -25,7 +25,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -25,7 +25,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -190,7 +190,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -190,7 +190,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if hasattr(self.language_model, "sampler"): if hasattr(self.language_model, "sampler"):
return self.language_model.sampler return self.language_model.sampler
return Sampler() return get_sampler()
def forward( def forward(
self, self,
......
...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -884,7 +884,7 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -884,7 +884,7 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.transformer.wte.weight self.lm_head.weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
......
...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -444,7 +444,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -444,7 +444,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
prefix, "lm_head")) prefix, "lm_head"))
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -36,7 +36,7 @@ from vllm.logger import init_logger ...@@ -36,7 +36,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
...@@ -295,7 +295,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -295,7 +295,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.text_config.vocab_size, config.text_config.vocab_size,
logit_scale) logit_scale)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
......
...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -393,7 +393,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): ...@@ -393,7 +393,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -52,7 +52,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -52,7 +52,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
...@@ -990,7 +990,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -990,7 +990,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
......
...@@ -42,7 +42,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -42,7 +42,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale) get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -449,7 +449,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -449,7 +449,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, config.vocab_size,
logit_scale) logit_scale)
self.sampler = Sampler() self.sampler = get_sampler()
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
......
...@@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -261,7 +261,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -261,7 +261,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -269,7 +269,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -269,7 +269,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
) )
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -21,7 +21,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, ...@@ -21,7 +21,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
...@@ -379,7 +379,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -379,7 +379,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
if hasattr(self.language_model, "sampler"): if hasattr(self.language_model, "sampler"):
return self.language_model.sampler return self.language_model.sampler
return Sampler() return get_sampler()
def _audio_features_to_embeddings( def _audio_features_to_embeddings(
self, input_features: torch.Tensor) -> torch.Tensor: self, input_features: torch.Tensor) -> torch.Tensor:
......
...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -334,7 +334,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -334,7 +334,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -136,7 +136,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -136,7 +136,7 @@ class FlashAttentionImpl(AttentionImpl):
"key/v_scale is not supported in FlashAttention.") "key/v_scale is not supported in FlashAttention.")
output = torch.empty_like(query) output = torch.empty_like(query)
torch.ops.vllm.unified_flash_attention( torch.ops.vllm.unified_v1_flash_attention(
output, output,
query, query,
key, key,
...@@ -156,7 +156,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -156,7 +156,7 @@ class FlashAttentionImpl(AttentionImpl):
return output return output
def unified_flash_attention( def unified_v1_flash_attention(
output: torch.Tensor, output: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
...@@ -222,7 +222,7 @@ def unified_flash_attention( ...@@ -222,7 +222,7 @@ def unified_flash_attention(
output[:num_actual_tokens].copy_(attn_output) output[:num_actual_tokens].copy_(attn_output)
def unified_flash_attention_fake( def unified_v1_flash_attention_fake(
output: torch.Tensor, output: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
...@@ -243,8 +243,8 @@ def unified_flash_attention_fake( ...@@ -243,8 +243,8 @@ def unified_flash_attention_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="unified_flash_attention", op_name="unified_v1_flash_attention",
op_func=unified_flash_attention, op_func=unified_v1_flash_attention,
mutates_args=["kv_cache", "output"], mutates_args=["kv_cache", "output"],
fake_impl=unified_flash_attention_fake, fake_impl=unified_v1_flash_attention_fake,
) )
...@@ -155,6 +155,12 @@ class LLMEngine: ...@@ -155,6 +155,12 @@ class LLMEngine:
# GPU and CPU blocks, which are profiled in the distributed executor. # GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
def __del__(self):
# Small hack- implicit clean up of resources on garbage collect
# TODO: this should probably be explicitly invoked when we're done with
# the engine
self.terminate_detokenizer()
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
) )
......
...@@ -73,7 +73,7 @@ class Detokenizer: ...@@ -73,7 +73,7 @@ class Detokenizer:
return None return None
def terminate(self) -> None: def terminate(self) -> None:
self.push_socket.send(b"", flags=zmq.NOBLOCK) self.detokenizer.kill()
self.detokenizer.join() self.detokenizer.join()
...@@ -108,10 +108,10 @@ class DetokenizerProc(multiprocessing.Process): ...@@ -108,10 +108,10 @@ class DetokenizerProc(multiprocessing.Process):
self.push_socket.bind(f"tcp://*:{self.push_port}") self.push_socket.bind(f"tcp://*:{self.push_port}")
while True: while True:
if self.pull_socket.poll(timeout=1000) == 0:
# Nothing to read
continue
message = self.pull_socket.recv() message = self.pull_socket.recv()
if message == b"":
# Terminate signal.
break
inputs = self.msgpack_decoder.decode(message) inputs = self.msgpack_decoder.decode(message)
for req_id in inputs.free_req_ids: for req_id in inputs.free_req_ids:
......
...@@ -2,7 +2,6 @@ import os ...@@ -2,7 +2,6 @@ import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set from typing import TYPE_CHECKING, Dict, List, Optional, Set
from unittest.mock import patch
import numpy as np import numpy as np
import torch import torch
...@@ -26,7 +25,6 @@ from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, ...@@ -26,7 +25,6 @@ from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata) FlashAttentionMetadata)
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.core.scheduler import SchedulerOutput
...@@ -418,7 +416,6 @@ class GPUModelRunner: ...@@ -418,7 +416,6 @@ class GPUModelRunner:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117 with DeviceMemoryProfiler() as m: # noqa: SIM117
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
self.model = get_model(vllm_config=self.vllm_config) self.model = get_model(vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment