Unverified Commit 27864a85 authored by zhima771's avatar zhima771 Committed by GitHub
Browse files

feat: support LoRA for DeepSeek-OCR(Language Model part) (#31569)


Signed-off-by: default avatarzhima771 <15836938703@163.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 5cc48766
...@@ -677,7 +677,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen ...@@ -677,7 +677,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ |
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ |
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | | `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ |
| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ | | `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | ✅︎ | ✅︎ |
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>E+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>E+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
......
...@@ -14,9 +14,11 @@ from vllm.config import VllmConfig ...@@ -14,9 +14,11 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
) )
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
...@@ -343,7 +345,7 @@ class DeepseekOCRMultiModalProcessor( ...@@ -343,7 +345,7 @@ class DeepseekOCRMultiModalProcessor(
info=DeepseekOCRProcessingInfo, info=DeepseekOCRProcessingInfo,
dummy_inputs=DeepseekOCRDummyInputsBuilder, dummy_inputs=DeepseekOCRDummyInputsBuilder,
) )
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
# map prefix for language backbone # map prefix for language backbone
...@@ -589,3 +591,13 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -589,3 +591,13 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return autoloaded_weights return autoloaded_weights
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="projector",
tower_model=["sam_model", "vision_model"],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment