Unverified Commit f64eae3a authored by Lianmin Zheng, committed by GitHub

[Fix] Reduce memory usage for loading llava model & Remove EntryClassRemapping (#1308)

parent a5a134f3
-name: Pull Request Test
+name: PR Test
 on:
   push:
......
@@ -205,7 +205,7 @@ It supports streaming, vision, and most features of the Chat/Completions/Models/
 ```
 python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --tp 2
 ```
-- Add `--dp 2` to enable multi-GPU data parallelism. It can also be used together with tensor parallelism. Data parallelism is better for throughput if there is enough memory.
+- Add `--dp 2` to enable multi-GPU data parallelism. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total.
 ```
 python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --dp 2 --tp 2
 ```
......
 # Custom Chat Template in SGLang Runtime
-By default, the server uses the chat template specified in the model tokenizer from Hugging Face. It should just work for most official models such as Llama-2/Llama-3.
+**NOTE**: There are two chat template systems in the SGLang project. This document is about setting a custom chat template for the OpenAI-compatible API server (defined at [conversation.py](../../python/sglang/srt/conversation.py)). It is NOT related to the chat template used in the SGLang language frontend (defined at [chat_template.py](../../python/sglang/lang/chat_template.py)).
+
+By default, the server uses the chat template specified in the model tokenizer from Hugging Face.
+It should just work for most official models such as Llama-2/Llama-3.
 If needed, you can also override the chat template when launching the server:
......
@@ -2,13 +2,8 @@
 Usage: python3 local_example_llava_next.py
 """
-from PIL import ImageFile
 import sglang as sgl
 from sglang.lang.chat_template import get_chat_template
-from sglang.srt.utils import load_image
-ImageFile.LOAD_TRUNCATED_IMAGES = True  # Allow loading of truncated images
 @sgl.function
......
@@ -4,7 +4,7 @@ from typing import List, Optional
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
 from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import (
@@ -23,6 +23,7 @@ class RuntimeEndpoint(BaseBackend):
         base_url: str,
         api_key: Optional[str] = None,
         verify: Optional[str] = None,
+        chat_template_name: Optional[str] = None,
     ):
         super().__init__()
         self.support_concate_and_append = True
@@ -39,9 +40,12 @@ class RuntimeEndpoint(BaseBackend):
         self._assert_success(res)
         self.model_info = res.json()
-        self.chat_template = get_chat_template_by_model_path(
-            self.model_info["model_path"]
-        )
+        if chat_template_name:
+            self.chat_template = get_chat_template(chat_template_name)
+        else:
+            self.chat_template = get_chat_template_by_model_path(
+                self.model_info["model_path"]
+            )
     def get_model_name(self):
         return self.model_info["model_path"]
......
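
For reference, a minimal usage sketch of the new `chat_template_name` override added above (illustrative only; the endpoint URL is a placeholder, and `"chatml"` is assumed to be a template name registered in `chat_template.py`):

```
import sglang as sgl

# Pick a frontend chat template by name instead of letting the backend
# auto-detect one from the served model's path.
backend = sgl.RuntimeEndpoint(
    "http://localhost:30000",
    chat_template_name="chatml",
)
sgl.set_default_backend(backend)
```
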
@@ -86,8 +86,8 @@ class TokenizerManager:
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
-        self.send_to_router = context.socket(zmq.PUSH)
-        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
+        self.send_to_controller = context.socket(zmq.PUSH)
+        self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
         # Read model args
         self.model_path = server_args.model_path
@@ -271,7 +271,7 @@ class TokenizerManager:
            input_ids,
            sampling_params,
        )
-        self.send_to_router.send_pyobj(tokenized_obj)
+        self.send_to_controller.send_pyobj(tokenized_obj)
        # Recv results
        event = asyncio.Event()
@@ -367,7 +367,7 @@ class TokenizerManager:
            input_ids,
            sampling_params,
        )
-        self.send_to_router.send_pyobj(tokenized_obj)
+        self.send_to_controller.send_pyobj(tokenized_obj)
        event = asyncio.Event()
        state = ReqState([], False, event)
@@ -500,14 +500,14 @@ class TokenizerManager:
     def flush_cache(self):
         req = FlushCacheReq()
-        self.send_to_router.send_pyobj(req)
+        self.send_to_controller.send_pyobj(req)
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
         del self.rid_to_state[rid]
         req = AbortReq(rid)
-        self.send_to_router.send_pyobj(req)
+        self.send_to_controller.send_pyobj(req)
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -524,7 +524,7 @@ class TokenizerManager:
         # wait for the previous generation requests to finish
         while len(self.rid_to_state) > 0:
             await asyncio.sleep(0)
-        self.send_to_router.send_pyobj(obj)
+        self.send_to_controller.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         result = await self.model_update_result
         if result.success:
......
@@ -606,16 +606,6 @@ def import_model_classes():
                 assert entry.__name__ not in model_arch_name_to_cls
                 model_arch_name_to_cls[entry.__name__] = entry
-            # compat: some models such as chatglm has incorrect class set in config.json
-            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
-            if hasattr(module, "EntryClassRemapping") and isinstance(
-                module.EntryClassRemapping, list
-            ):
-                for remap in module.EntryClassRemapping:
-                    if isinstance(remap, tuple) and len(remap) == 2:
-                        assert remap[0] not in model_arch_name_to_cls
-                        model_arch_name_to_cls[remap[0]] = remap[1]
     return model_arch_name_to_cls
......
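
With `EntryClassRemapping` removed, a model file that needs to answer to a second architecture name from `config.json` now simply exports an alias class through `EntryClass`, which may be a single class or a list. A rough sketch of the registration idea (illustrative; not the exact `import_model_classes` code):

```
def register_entry_classes(model_arch_name_to_cls, entry_class):
    # EntryClass may be a single class or a list of classes. Each class is
    # keyed by its own __name__, so an alias is just an empty subclass whose
    # name matches the architecture string in config.json.
    entries = entry_class if isinstance(entry_class, list) else [entry_class]
    for entry in entries:
        assert entry.__name__ not in model_arch_name_to_cls
        model_arch_name_to_cls[entry.__name__] = entry
```
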
@@ -402,6 +402,8 @@ class ChatGLMForCausalLM(nn.Module):
             weight_loader(param, loaded_weight)
-EntryClass = ChatGLMForCausalLM
-# compat: glm model.config class == ChatGLMModel
-EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
+class ChatGLMModel(ChatGLMForCausalLM):
+    pass
+
+
+EntryClass = [ChatGLMForCausalLM, ChatGLMModel]
@@ -297,7 +297,6 @@ class ExaoneForCausalLM(nn.Module):
         config,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
-        efficient_weight_load=False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -345,9 +344,7 @@ class ExaoneForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         return len(params_dict)
-    def load_weights(
-        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -358,7 +355,7 @@ class ExaoneForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
-        def load_weights_per_param(name, loaded_weight):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 return
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -368,6 +365,7 @@ class ExaoneForCausalLM(nn.Module):
             if name.startswith("model.vision_tower") and name not in params_dict:
                 return
+            name = name.replace("attn.attention", "self_attn")
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -387,13 +385,5 @@ class ExaoneForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-        if name is None or loaded_weight is None:
-            for name, loaded_weight in weights:
-                name = name.replace("attn.attention", "self_attn")
-                load_weights_per_param(name, loaded_weight)
-        else:
-            name = name.replace("attn.attention", "self_attn")
-            load_weights_per_param(name, loaded_weight)
 EntryClass = ExaoneForCausalLM
@@ -295,7 +295,6 @@ class LlamaForCausalLM(nn.Module):
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
-        efficient_weight_load=False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -305,6 +304,8 @@ class LlamaForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.sampler = Sampler()
+        self.param_dict = dict(self.named_parameters())
     @torch.no_grad()
     def forward(
         self,
@@ -320,30 +321,7 @@ class LlamaForCausalLM(nn.Module):
         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
         return sample_output, logits_output
-    def get_module_name(self, name):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id, num_shard)
-            ("qkv_proj", "q_proj", "q", 3),
-            ("qkv_proj", "k_proj", "k", 3),
-            ("qkv_proj", "v_proj", "v", 3),
-            ("gate_up_proj", "gate_proj", 0, 2),
-            ("gate_up_proj", "up_proj", 1, 2),
-        ]
-        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
-            if weight_name in name:
-                return (
-                    name.replace(weight_name, param_name)[: -len(".weight")],
-                    num_shard,
-                )
-        return name[: -len(".weight")], 1
-
-    def get_num_params(self):
-        params_dict = dict(self.named_parameters())
-        return len(params_dict)
-
-    def load_weights(
-        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -352,9 +330,9 @@ class LlamaForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())
+        params_dict = self.param_dict
-        def load_weights_per_param(name, loaded_weight):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 return
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -383,11 +361,5 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-        if name is None or loaded_weight is None:
-            for name, loaded_weight in weights:
-                load_weights_per_param(name, loaded_weight)
-        else:
-            load_weights_per_param(name, loaded_weight)
 EntryClass = LlamaForCausalLM
@@ -16,17 +16,16 @@ limitations under the License.
 from typing import Iterable, Optional, Tuple
 import torch
-import tqdm
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
-from sglang.srt.models.llama2 import LlamaModel
+from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
 class LlamaForClassification(nn.Module):
@@ -42,10 +41,12 @@ class LlamaForClassification(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.classification_head = nn.Linear(
-            config.hidden_size, config.classification_out_size
+            config.hidden_size, config.classification_out_size, bias=False
         )
         self.eos_token_id = config.eos_token_id
+        self.param_dict = dict(self.named_parameters())
     @torch.no_grad()
     def forward(
         self,
@@ -65,7 +66,7 @@ class LlamaForClassification(nn.Module):
             (input_metadata.batch_size, self.config.classification_out_size)
         ).to(input_ids.device)
-        return LogitsProcessorOutput(
+        logits_output = LogitsProcessorOutput(
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
@@ -74,46 +75,38 @@ class LlamaForClassification(nn.Module):
             output_top_logprobs=None,
         )
+        # A dummy to make this work
+        sample_output = SampleOutput(
+            success=torch.full(
+                size=(scores.shape[0],),
+                fill_value=True,
+                dtype=torch.bool,
+            ),
+            probs=torch.full(
+                size=(scores.shape[0], 1),
+                fill_value=1.0,
+                dtype=torch.float16,
+            ),
+            batch_next_token_ids=torch.full(
+                size=(scores.shape[0],),
+                fill_value=0,
+                dtype=torch.long,
+            ),
+        )
+        return sample_output, logits_output
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name or "projector" in name:
-                continue
-            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            if "lm_head" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
+        params_dict = self.param_dict
+        for name, loaded_weight in weights:
+            if "classification_head" in name:
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+            elif "lm_head" in name:
+                continue
+            else:
+                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
 EntryClass = LlamaForClassification
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Tuple
 import torch
 from torch import nn
@@ -7,7 +7,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.model_executor.model_runner import InputMetadata
-from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
+from sglang.srt.models.llama import LlamaModel
 class LlamaEmbeddingModel(nn.Module):
@@ -16,7 +16,6 @@ class LlamaEmbeddingModel(nn.Module):
         config: LlamaConfig,
         quant_config=None,
         cache_config=None,
-        efficient_weight_load=False,
     ) -> None:
         super().__init__()
         self.model = LlamaModel(config, quant_config=quant_config)
@@ -86,6 +85,8 @@ class LlamaEmbeddingModel(nn.Module):
         load_weights_per_param(name, loaded_weight)
-EntryClass = LlamaEmbeddingModel
-# compat: e5-mistral model.config class == MistralModel
-EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
+class MistralModel(LlamaEmbeddingModel):
+    pass
+
+
+EntryClass = [LlamaEmbeddingModel, MistralModel]
@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -395,21 +395,19 @@ class LlavaBaseForCausalLM(nn.Module):
             "model.mm_projector.0": "multi_modal_projector.linear_1",
             "model.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
-        weights = list(weights)
         for name, loaded_weight in weights:
-            # FIXME: why projector weights read two times?
-            if "projector" in name or "vision_tower" in name:
+            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
                         name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-        # load language model
-        self.language_model.load_weights(weights)
+            else:
+                self.language_model.load_weights([(name, loaded_weight)])
     @property
     def num_patches_per_side(self):
@@ -429,6 +427,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self.vision_tower = None
         self.config.vision_config.hidden_size = config.mm_hidden_size
         self.config.text_config.hidden_size = config.hidden_size
         self.multi_modal_projector = LlavaMultiModalProjector(config)
         self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
@@ -448,9 +447,9 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
         if getattr(self.config, "text_config", None) is None:
             self.config.text_config = Qwen2Config(self.config._name_or_path)
@@ -459,7 +458,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         if getattr(self.config, "projector_hidden_act", None) is None:
             self.config.projector_hidden_act = "gelu"
         if getattr(self.config, "image_token_index", None) is None:
             self.config.image_token_index = 151646
@@ -482,9 +480,9 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
         if getattr(self.config, "text_config", None) is None:
             self.config.text_config = MistralConfig(self.config._name_or_path)
@@ -493,7 +491,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         if getattr(self.config, "projector_hidden_act", None) is None:
             self.config.projector_hidden_act = "gelu"
         if getattr(self.config, "image_token_index", None) is None:
             self.config.image_token_index = 32000
......
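
The llava change above is where the memory saving in this commit comes from: the old code materialized the whole checkpoint iterator with `weights = list(weights)` before handing it to the language model, keeping every tensor alive at once, while the new code consumes the iterator a single time and forwards each non-vision tensor to `self.language_model.load_weights` one pair at a time. A standalone sketch of that streaming pattern (illustrative only; the two loader callbacks are hypothetical names, not SGLang APIs):

```
from typing import Callable, Iterable, Tuple

import torch


def dispatch_weights_streaming(
    weights: Iterable[Tuple[str, torch.Tensor]],
    load_projector_weight: Callable[[str, torch.Tensor], None],
    load_language_weight: Callable[[Iterable[Tuple[str, torch.Tensor]]], None],
) -> None:
    """Consume the checkpoint iterator once instead of materializing a list."""
    for name, loaded_weight in weights:
        if "projector" in name or "vision_tower" in name or "image_newline" in name:
            load_projector_weight(name, loaded_weight)
        else:
            # A single-element list keeps the language model's load_weights
            # signature (an iterable of (name, tensor) pairs) unchanged.
            load_language_weight([(name, loaded_weight)])
```
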
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.llama import LlamaForCausalLM
 class LlavaVidForCausalLM(nn.Module):
@@ -239,12 +239,12 @@ class LlavaVidForCausalLM(nn.Module):
             "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
             "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
-        weights = list(weights)
         for name, loaded_weight in weights:
             # FIXME: why projector weights read two times?
-            if "projector" in name or "vision_tower" in name:
+            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
                         name = name.replace(weight_name, param_name)
@@ -255,9 +255,8 @@ class LlavaVidForCausalLM(nn.Module):
                     continue
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-        # load language model
-        self.language_model.load_weights(weights)
+            else:
+                self.language_model.load_weights([(name, loaded_weight)])
     @property
     def num_patches_per_side(self):
......
@@ -15,12 +15,11 @@ limitations under the License.
 """Inference-only Mistral model."""
-from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.llama import LlamaForCausalLM
 class MistralForCausalLM(LlamaForCausalLM):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    pass
 EntryClass = MistralForCausalLM
""" """
Usage: Usage:
python3 -m sglang.launch_server --model-path /model/llama-classification python3 -m sglang.launch_server --disable-cuda-graph --model-path /model/llama-classification
python3 test_httpserver_classify.py python3 test_httpserver_classify.py
""" """
......
@@ -3,23 +3,24 @@ Usage:
 python3 reference_hf.py --model TinyLlama/TinyLlama-1.1B-Chat-v0.4
 Reference output:
+========== Prompt 0 ==========
+prefill logits (final) tensor([-8.3125, -7.1172, 3.3398, ..., -4.9531, -4.1328, -3.4141],
+       device='cuda:0')
 <s> The capital of France is Paris.
 The capital of the United States is Washington, D.C.
-The capital of Canada is Ottawa.
-The capital of Japan is Tokyo
-prefill logits tensor([-8.3125, -7.1172, 3.3398, ..., -4.9570, -4.1328, -3.4141],
+
+========== Prompt 1 ==========
+prefill logits (final) tensor([-8.9062, -9.0156, 4.1484, ..., -4.9922, -4.4961, -4.0742],
        device='cuda:0')
 <s> The capital of the United Kindom is London.
 The capital of the United Kingdom is London.
-The capital of the United Kingdom is London.
-The capital of the United Kingdom is London.
-prefill logits tensor([-8.9062, -9.0156, 4.1406, ..., -4.9922, -4.4961, -4.0742],
+The capital of
+
+========== Prompt 2 ==========
+prefill logits (final) tensor([-9.6328, -9.0547, 4.0234, ..., -5.3047, -4.7148, -4.4609],
        device='cuda:0')
 <s> Today is a sunny day and I like to go for a walk in the park.
-I'm going to the park to play in the grass and water.
-Today is a very
-prefill logits tensor([-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4609],
-       device='cuda:0')
+I'm going to the
 """
 import argparse
@@ -47,7 +48,7 @@ def normal_text(args):
     ]
     max_new_tokens = 16
-    for p in prompts:
+    for i, p in enumerate(prompts):
         if isinstance(p, str):
             input_ids = t.encode(p, return_tensors="pt").cuda()
         else:
@@ -60,7 +61,8 @@ def normal_text(args):
         prefill_logits = m.forward(input_ids).logits[0][-1]
-        print("prefill logits", prefill_logits)
+        print(f"\n========== Prompt {i} ==========")
+        print("prefill logits (final)", prefill_logits)
         print(output_str)
......