Unverified commit a1fc18c0, authored by Aleksandr Malyshev, committed by GitHub
Browse files

[ROCm][AMD][Model] llama 3.2 support upstreaming (#12421)


Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
parent 9798b2fb
This diff is collapsed.
......@@ -48,7 +48,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SequenceData
......@@ -847,7 +848,8 @@ class MllamaTextCrossAttention(nn.Module):
i,
i,
)
elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
elif self.attn.backend in (_Backend.XFORMERS, _Backend.ROCM_FLASH,
_Backend.TORCH_SDPA):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
......@@ -859,7 +861,8 @@ class MllamaTextCrossAttention(nn.Module):
raise ValueError(
f"Unsupported Attention backend {self.attn.backend} "
"enum found. Expected the Attention backend to be "
"FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.")
"FLASH_ATTN, FLASH_ATTN_VLLM_V1, "
"XFORMERS or TORCH_SDPA.")
# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
......@@ -1452,6 +1455,13 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
weight_loader(param, loaded_weight, shard_id)
break
else:
orig_name = name
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
logger.debug("Missing name %s, orig name %s", name,
orig_name)
continue
param = params_dict.pop(name)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment