Commit 0a97d796 authored by Lianmin Zheng

[Fix] Fix OOM in llava base class (#1249)

parent c411f32e
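Why this likely fixes the OOM (a reading of the diff below, not a statement from the commit itself): before this change, LlavaQwenForCausalLM and LlavaMistralForCausalLM inherited from LlavaLlamaForCausalLM and forwarded config into its __init__, which eagerly constructs a full LlamaForCausalLM. Each subclass then built its own language model on top of that, so two sets of language-model weights could be materialized in the same process, which is enough to exhaust GPU memory for 72B-scale checkpoints. The refactor extracts a LlavaBaseForCausalLM that constructs no model, so every concrete class allocates exactly one language model. A minimal sketch of the anti-pattern and the fix (FakeLM and the Old*/New* class names are illustrative stand-ins, not sglang code):

import torch
import torch.nn as nn


class FakeLM(nn.Module):
    """Stand-in for a large language model; the parameter plays the role of its weights."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(1024, 1024))


class OldBase(nn.Module):
    # Pre-fix pattern: the base constructor eagerly builds a language model.
    def __init__(self):
        super().__init__()
        self.language_model = FakeLM()  # first allocation


class OldQwen(OldBase):
    def __init__(self):
        super().__init__()              # allocates the base's model ...
        self.language_model = FakeLM()  # ... then replaces it; both exist briefly


class NewBase(nn.Module):
    # Post-fix pattern: the base holds only shared helpers and never builds a model.
    pass


class NewQwen(NewBase):
    def __init__(self):
        super().__init__()
        self.language_model = FakeLM()  # the only allocation

With real multi-billion-parameter weights, the transient double allocation in OldQwen can tip a tight GPU memory budget over the edge, even though the first model becomes garbage shortly after reassignment.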
""" """
Usage: Usage:
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384 python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava
python3 http_llava_onevision_test.py python3 http_llava_onevision_test.py
""" """
...
@@ -46,25 +46,7 @@ from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM


-class LlavaLlamaForCausalLM(nn.Module):
-    def __init__(
-        self,
-        config: LlavaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.vision_tower = None
-        self.config.vision_config.hidden_size = config.mm_hidden_size
-        self.config.text_config.hidden_size = config.hidden_size
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
-            )
-
+class LlavaBaseForCausalLM(nn.Module):
     def pad_input_ids(
         self,
         input_ids: List[int],
@@ -434,14 +416,36 @@ class LlavaLlamaForCausalLM(nn.Module):
         return self.image_size // self.patch_size


-class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.vision_tower = None
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
@@ -467,14 +471,15 @@ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
             )


-class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
...
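Design note on the hunks above: LlavaBaseForCausalLM now owns only the shared multimodal plumbing (pad_input_ids and, presumably, the shared forward logic), while LlavaLlamaForCausalLM, LlavaQwenForCausalLM, and LlavaMistralForCausalLM each construct their single language model and call the bare nn.Module initializer via super().__init__(), instead of forwarding config into a parent constructor that would build an unwanted LlamaForCausalLM.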
@@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.5",
+            "0.1.6",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
...
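The second file raises the minimum supported flashinfer version from 0.1.5 to 0.1.6. assert_pkg_version is sglang's internal startup guard that refuses to launch the server against an outdated package; roughly, a check of this kind can be written as below (check_min_version is a hypothetical name used for illustration, not sglang's actual helper):

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def check_min_version(pkg: str, min_ver: str, help_msg: str) -> None:
    """Fail fast when `pkg` is absent or older than `min_ver`."""
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {help_msg}")
    if Version(installed) < Version(min_ver):
        raise RuntimeError(
            f"{pkg}=={installed} is older than the required {min_ver}. {help_msg}"
        )


# Mirrors the call site in the diff:
# check_min_version("flashinfer", "0.1.6", "See https://docs.flashinfer.ai/installation.html.")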