"vllm/vscode:/vscode.git/clone" did not exist on "e439c784fa318dbc23c04b0730bee0fccf46481d"
Unverified Commit e14467be authored by Divakar Verma's avatar Divakar Verma Committed by GitHub
Browse files

[bugfix] Aria model (#32727)


Signed-off-by: default avatarDivakar Verma <divakar.verma@amd.com>
parent 7727ce35
......@@ -15,7 +15,9 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
......@@ -554,6 +556,18 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
prefix=maybe_prefix(prefix, "language_model.model"),
)
self.lm_head = ParallelLMHead(
config.text_config.vocab_size,
config.text_config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
config.text_config.vocab_size, scale=logit_scale
)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> AriaImagePixelInputs | None:
......@@ -637,7 +651,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment