Commit df81235f authored by zhuwenwen
Browse files

[Model] Update Qwen2-Audio model support

parent 3c23ce2d
...@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers, make_empty_intermediate_tensors_factory
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
...@@ -265,6 +265,10 @@ class Qwen2Model(nn.Module): ...@@ -265,6 +265,10 @@ class Qwen2Model(nn.Module):
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else: else:
...@@ -370,6 +374,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -370,6 +374,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
......
...@@ -29,6 +29,11 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" ...@@ -29,6 +29,11 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l"
VLLM_INVALID_TOKEN_ID = -1 VLLM_INVALID_TOKEN_ID = -1
def array_full(token_id: int, count: int):
    """Return a token-id :class:`array` filled with one repeated value.

    This is the :class:`array` counterpart of :func:`numpy.full`: the
    result uses the ``VLLM_TOKEN_ID_ARRAY_TYPE`` typecode and contains
    ``token_id`` repeated ``count`` times.
    """
    # Build a one-element array first, then let sequence repetition
    # expand it to the requested length.
    single_token = array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id])
    return single_token * count
# We use dataclass for now because it is used for # We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable. # openai server output, and msgspec is not serializable.
# TODO(sang): Fix it. # TODO(sang): Fix it.
...@@ -174,6 +179,26 @@ class SequenceData(msgspec.Struct, ...@@ -174,6 +179,26 @@ class SequenceData(msgspec.Struct,
_first_step_flag: bool = True _first_step_flag: bool = True
@staticmethod
def from_prompt_token_counts(
        *token_counts: Tuple[int, int]) -> "SequenceData":
    """
    Build a :class:`SequenceData` whose prompt is a concatenation of
    runs of repeated tokens.

    Every positional argument is a :code:`(token_id, count)` pair that
    describes one run: ``token_id`` repeated ``count`` times. The runs
    are joined in the order given.

    When called without arguments, the result of
    :code:`SequenceData.from_seqs([])` is returned.
    """
    if not token_counts:
        return SequenceData.from_seqs([])

    # Each run is a freshly created array, so folding with the in-place
    # ``__iadd__`` only ever mutates the first run's own buffer.
    runs = (array_full(tok, repeat) for tok, repeat in token_counts)
    merged_prompt = reduce(array.__iadd__, runs)

    return SequenceData(merged_prompt)
@staticmethod @staticmethod
def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData": def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
if len(token_counts) == 0: if len(token_counts) == 0:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment