Unverified Commit e14467be authored by Divakar Verma's avatar Divakar Verma Committed by GitHub
Browse files

[bugfix] Aria model (#32727)


Signed-off-by: default avatarDivakar Verma <divakar.verma@amd.com>
parent 7727ce35
...@@ -15,7 +15,9 @@ from vllm.distributed import get_tensor_model_parallel_rank ...@@ -15,7 +15,9 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
...@@ -554,6 +556,18 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -554,6 +556,18 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
prefix=maybe_prefix(prefix, "language_model.model"), prefix=maybe_prefix(prefix, "language_model.model"),
) )
self.lm_head = ParallelLMHead(
config.text_config.vocab_size,
config.text_config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
config.text_config.vocab_size, scale=logit_scale
)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
) -> AriaImagePixelInputs | None: ) -> AriaImagePixelInputs | None:
...@@ -637,7 +651,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -637,7 +651,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states) logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment