Unverified Commit 969bbc7c authored by Zhonghua Deng's avatar Zhonghua Deng Committed by GitHub
Browse files

[Model] Add MiMo-V2-Flash support (#30836)


Signed-off-by: Abatom <abzhonghua@gmail.com>
Signed-off-by: Jumiar <liuanqim10@126.com>
Signed-off-by: Zyann7 <zyann7@outlook.com>
Co-authored-by: Jumiar <liuanqim10@126.com>
Co-authored-by: Zyann7 <zyann7@outlook.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent 268a972c
...@@ -415,6 +415,7 @@ th { ...@@ -415,6 +415,7 @@ th {
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ |
| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | | `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ |
| `MiMoV2FlashForCausalLM` | MiMoV2Flash | `XiaomiMiMo/MiMo-V2-Flash`, etc. | | ✅︎ |
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ |
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ |
| `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ | | `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ |
......
...@@ -459,6 +459,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -459,6 +459,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
), ),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True), "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True),
"MiMoV2FlashForCausalLM": _HfExamplesInfo(
"XiaomiMiMo/MiMo-V2-Flash", trust_remote_code=True
),
"Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"), "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"),
} }
......
...@@ -18,6 +18,7 @@ from vllm.config.lora import LoRAConfig ...@@ -18,6 +18,7 @@ from vllm.config.lora import LoRAConfig
from vllm.config.model import ( from vllm.config.model import (
ModelConfig, ModelConfig,
iter_architecture_defaults, iter_architecture_defaults,
str_dtype_to_torch_dtype,
try_match_architecture_defaults, try_match_architecture_defaults,
) )
from vllm.config.multimodal import MultiModalConfig from vllm.config.multimodal import MultiModalConfig
...@@ -72,6 +73,7 @@ __all__ = [ ...@@ -72,6 +73,7 @@ __all__ = [
# From vllm.config.model # From vllm.config.model
"ModelConfig", "ModelConfig",
"iter_architecture_defaults", "iter_architecture_defaults",
"str_dtype_to_torch_dtype",
"try_match_architecture_defaults", "try_match_architecture_defaults",
# From vllm.config.multimodal # From vllm.config.multimodal
"MultiModalConfig", "MultiModalConfig",
......
...@@ -1849,6 +1849,11 @@ _STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -1849,6 +1849,11 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
} }
def str_dtype_to_torch_dtype(type: str):
    """Look up the torch dtype registered under a dtype name.

    Args:
        type: Dtype name used as the key into ``_STR_DTYPE_TO_TORCH_DTYPE``
            (for example ``"bfloat16"``).

    Returns:
        The value stored in ``_STR_DTYPE_TO_TORCH_DTYPE`` for ``type``, or
        ``None`` when the name is not a key of that mapping.
    """
    # NOTE: the parameter name shadows the ``type`` builtin; it is kept as-is
    # so that existing keyword callers keep working.
    torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(type, None)
    return torch_dtype
# model_type -> reason # model_type -> reason
_FLOAT16_NOT_SUPPORTED_MODELS = { _FLOAT16_NOT_SUPPORTED_MODELS = {
"gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
......
...@@ -277,6 +277,7 @@ class LinearBase(CustomOp): ...@@ -277,6 +277,7 @@ class LinearBase(CustomOp):
self.params_dtype = params_dtype self.params_dtype = params_dtype
self.quant_config = quant_config self.quant_config = quant_config
self.prefix = prefix self.prefix = prefix
self.allow_fp8_block_shape_mismatch = False
if quant_config is None: if quant_config is None:
self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod() self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
else: else:
...@@ -475,6 +476,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -475,6 +476,7 @@ class ColumnParallelLinear(LinearBase):
disable_tp=disable_tp, disable_tp=disable_tp,
) )
self._maybe_allow_fp8_block_shape_mismatch()
self.gather_output = gather_output self.gather_output = gather_output
if output_sizes is None: if output_sizes is None:
...@@ -509,6 +511,33 @@ class ColumnParallelLinear(LinearBase): ...@@ -509,6 +511,33 @@ class ColumnParallelLinear(LinearBase):
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.update_param_tp_status() self.update_param_tp_status()
def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
    """Set ``allow_fp8_block_shape_mismatch`` when a partition is unaligned.

    Reads ``weight_block_size`` from ``self.quant_config`` (if present) and
    takes its first entry as ``block_n``. If this layer has more than one
    output partition and at least one partition size is not a multiple of
    ``block_n``, the flag ``self.allow_fp8_block_shape_mismatch`` is set to
    ``True`` and a debug message is logged.

    The method returns without touching the flag when:
      * there is no quant config or no ``weight_block_size``,
      * ``weight_block_size`` is empty,
      * the layer has at most one output partition,
      * the first block entry cannot be converted to a positive ``int``,
      * every partition size is already divisible by ``block_n``.
    """
    cfg = getattr(self, "quant_config", None)
    block_shape = getattr(cfg, "weight_block_size", None)

    # No block-wise quantization configured: nothing to check.
    if block_shape is None:
        return
    if len(block_shape) < 1:
        return

    # A layer with a single output partition is not considered here.
    if len(self.output_partition_sizes) <= 1:
        return

    try:
        block_n = int(block_shape[0])
    except (ValueError, TypeError):
        # A non-numeric block entry is ignored; the flag keeps its value.
        return
    if block_n <= 0:
        return

    has_unaligned_partition = any(
        part % block_n != 0 for part in self.output_partition_sizes
    )
    if not has_unaligned_partition:
        return

    self.allow_fp8_block_shape_mismatch = True
    logger.debug(
        "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
        getattr(self, "prefix", "<unknown>"),
        block_n,
        self.output_partition_sizes,
    )
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
...@@ -906,9 +935,11 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -906,9 +935,11 @@ class QKVParallelLinear(ColumnParallelLinear):
*, *,
return_bias: bool = True, return_bias: bool = True,
disable_tp: bool = False, disable_tp: bool = False,
v_head_size: int | None = None,
): ):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = head_size self.head_size = head_size
self.v_head_size = v_head_size if v_head_size is not None else head_size
self.total_num_heads = total_num_heads self.total_num_heads = total_num_heads
if total_num_kv_heads is None: if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads total_num_kv_heads = total_num_heads
...@@ -924,12 +955,14 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -924,12 +955,14 @@ class QKVParallelLinear(ColumnParallelLinear):
self.num_kv_head_replicas = 1 self.num_kv_head_replicas = 1
input_size = self.hidden_size input_size = self.hidden_size
output_size = ( output_size = (
(self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size self.num_heads * self.head_size
) + self.num_kv_heads * self.head_size
+ self.num_kv_heads * self.v_head_size
) * tp_size
self.output_sizes = [ self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj self.num_kv_heads * self.v_head_size * tp_size, # v_proj
] ]
super().__init__( super().__init__(
...@@ -950,7 +983,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -950,7 +983,8 @@ class QKVParallelLinear(ColumnParallelLinear):
"q": 0, "q": 0,
"k": self.num_heads * self.head_size, "k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size, "v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size, "total": (self.num_heads + self.num_kv_heads) * self.head_size
+ self.num_kv_heads * self.v_head_size,
} }
return shard_offset_mapping.get(loaded_shard_id) return shard_offset_mapping.get(loaded_shard_id)
...@@ -958,7 +992,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -958,7 +992,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size_mapping = { shard_size_mapping = {
"q": self.num_heads * self.head_size, "q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size, "k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.head_size, "v": self.num_kv_heads * self.v_head_size,
} }
return shard_size_mapping.get(loaded_shard_id) return shard_size_mapping.get(loaded_shard_id)
...@@ -985,7 +1019,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -985,7 +1019,7 @@ class QKVParallelLinear(ColumnParallelLinear):
( (
"v", "v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size, (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size, self.total_num_kv_heads * self.v_head_size,
), ),
] ]
...@@ -1110,7 +1144,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1110,7 +1144,7 @@ class QKVParallelLinear(ColumnParallelLinear):
( (
"v", "v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size, (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size, self.total_num_kv_heads * self.v_head_size,
), ),
] ]
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
...@@ -1139,11 +1173,12 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1139,11 +1173,12 @@ class QKVParallelLinear(ColumnParallelLinear):
"v": ( "v": (
(self.total_num_heads + self.total_num_kv_heads) (self.total_num_heads + self.total_num_kv_heads)
* self.head_size, * self.head_size,
self.total_num_kv_heads * self.head_size, self.total_num_kv_heads * self.v_head_size,
), ),
"total": ( "total": (
(self.total_num_heads + 2 * self.total_num_kv_heads) (self.total_num_heads + self.total_num_kv_heads)
* self.head_size, * self.head_size
+ self.total_num_kv_heads * self.v_head_size,
0, 0,
), ),
} }
...@@ -1170,7 +1205,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1170,7 +1205,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size = self.num_kv_heads * self.head_size shard_size = self.num_kv_heads * self.head_size
elif loaded_shard_id == "v": elif loaded_shard_id == "v":
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size shard_size = self.num_kv_heads * self.v_head_size
# Special case for Quantized Weights. # Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account # If quantized, we need to adjust the offset and size to account
# for the packing. # for the packing.
...@@ -1199,10 +1234,11 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1199,10 +1234,11 @@ class QKVParallelLinear(ColumnParallelLinear):
), ),
"v": ( "v": (
(self.num_heads + self.num_kv_heads) * self.head_size, (self.num_heads + self.num_kv_heads) * self.head_size,
self.num_kv_heads * self.head_size, self.num_kv_heads * self.v_head_size,
), ),
"total": ( "total": (
(self.num_heads + 2 * self.num_kv_heads) * self.head_size, (self.num_heads + self.num_kv_heads) * self.head_size
+ self.num_kv_heads * self.v_head_size,
0, 0,
), ),
} }
......
...@@ -1252,6 +1252,14 @@ def validate_fp8_block_shape( ...@@ -1252,6 +1252,14 @@ def validate_fp8_block_shape(
"""Validate block quantization shapes for tensor parallelism.""" """Validate block quantization shapes for tensor parallelism."""
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
if getattr(layer, "allow_fp8_block_shape_mismatch", False):
logger.debug(
"Skipping FP8 block shape validation for layer %s due to detected"
" mismatch allowance.",
getattr(layer, "prefix", "<unknown>"),
)
return
tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size()) tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size())
block_n, block_k = block_size[0], block_size[1] block_n, block_k = block_size[0], block_size[1]
......
This diff is collapsed.
...@@ -152,6 +152,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -152,6 +152,7 @@ _TEXT_GENERATION_MODELS = {
"MptForCausalLM": ("mpt", "MPTForCausalLM"), "MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiMoForCausalLM": ("mimo", "MiMoForCausalLM"), "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
"MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"), "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment