Commit df81235f authored by zhuwenwen
Browse files

[Model] Update Qwen2-Audio model support

parent 3c23ce2d
......@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers, make_empty_intermediate_tensors_factory
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
......@@ -265,6 +265,10 @@ class Qwen2Model(nn.Module):
prefix=f"{prefix}.layers",
)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
......@@ -370,6 +374,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
......
......@@ -29,6 +29,11 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l"
VLLM_INVALID_TOKEN_ID = -1
def array_full(token_id: int, count: int):
    """Return a token-id :class:`array` made of ``token_id`` repeated.

    This plays the role of :func:`numpy.full` for the typed arrays used to
    store token ids: a one-element array is built with the module's token-id
    type code and then repeated ``count`` times.
    """
    single_token = array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id])
    return single_token * count
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
......@@ -174,6 +179,26 @@ class SequenceData(msgspec.Struct,
_first_step_flag: bool = True
@staticmethod
def from_prompt_token_counts(
        *token_counts: Tuple[int, int]) -> "SequenceData":
    """
    Build a :class:`SequenceData` whose prompt is the concatenation of
    several constant-valued token runs.

    Every positional argument is a :code:`(token_id, count)` pair that
    describes one run: ``token_id`` repeated ``count`` times. Runs are
    joined in the order given. With no arguments, the result is whatever
    :code:`SequenceData.from_seqs([])` produces.
    """
    if not token_counts:
        return SequenceData.from_seqs([])

    # Each run is a freshly created array, so extending the first one in
    # place does not touch any caller-owned data.
    runs = (array_full(token_id, count) for token_id, count in token_counts)
    prompt_token_ids_arr = next(runs)
    for run in runs:
        prompt_token_ids_arr += run

    return SequenceData(prompt_token_ids_arr)
@staticmethod
def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
if len(token_counts) == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment