Unverified Commit 4936be8a authored by Lianmin Zheng, committed by GitHub

Revert "Revert "[FEAT] Support GGUF format"" (#2287)

parent 1bfa511b
@@ -585,12 +585,10 @@ class MiniCPM3ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
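Review note: the change repeated across the model files in this commit is that the logits processor now receives the `lm_head` module itself rather than its bare `lm_head.weight` tensor. A minimal sketch of why the module is the more useful argument, assuming a hypothetical head that carries a vLLM-style `quant_method`; the actual LogitsProcessor in sglang may resolve this differently:

import torch
import torch.nn as nn

def compute_logits(hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
    # If the head carries a quantization method (e.g. a GGUF linear/embedding
    # method), let it produce the logits so the quantized storage is respected.
    quant_method = getattr(lm_head, "quant_method", None)
    if quant_method is not None:
        return quant_method.apply(lm_head, hidden_states)
    # Otherwise fall back to a plain dense projection against the (tied) weight.
    return torch.matmul(hidden_states, lm_head.weight.t())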
@@ -310,7 +310,7 @@ class MixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -343,7 +343,7 @@ class QuantMixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
             skip_cross_attention=skip_cross_attention,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.language_model.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -306,7 +306,7 @@ class OlmoForCausalLM(nn.Module):
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -326,11 +326,6 @@ class OlmoForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            # With tie_word_embeddings, we can skip lm_head.weight
-            # The weight might appear unnecessarily in the files if the model is
-            # processed with quantization, LoRA, fine-tuning, etc.
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
...
@@ -321,7 +321,7 @@ class OlmoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -397,10 +397,13 @@ class Phi3SmallForCausalLM(nn.Module):

     def compute_logits(
         self,
+        input_ids: torch.LongTensor,
         hidden_states: torch.Tensor,
         sampling_metadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+        logits = self.logits_processor(
+            input_ids, self.lm_head, hidden_states, sampling_metadata
+        )
         if self.dummy_token_indices is not None and logits is not None:
             logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
         return logits
@@ -422,7 +425,7 @@ class Phi3SmallForCausalLM(nn.Module):
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
...
@@ -260,7 +260,7 @@ class QWenLMHeadModel(nn.Module):
     ):
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -230,6 +230,7 @@ class Qwen2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -276,7 +277,12 @@ class Qwen2ForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -292,7 +298,7 @@ class Qwen2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -306,6 +312,7 @@ class Qwen2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -335,11 +342,6 @@ class Qwen2ForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if (
-                    self.config.tie_word_embeddings
-                    and name == "model.embed_tokens.weight"
-                ):
-                    weight_loader(params_dict["lm_head.weight"], loaded_weight)

 EntryClass = Qwen2ForCausalLM
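Side note on the Qwen2 change (and the analogous TorchNativeLlama change further down): with tied word embeddings, the constructor now aliases `self.lm_head = self.model.embed_tokens` instead of copying `embed_tokens.weight` into `lm_head.weight` during `load_weights`. Aliasing shares the (possibly GGUF-quantized) parameter by reference, so there is no second copy to keep in sync. A minimal illustration of the idea, not Qwen2-specific:

import torch.nn as nn

embed_tokens = nn.Embedding(16, 4)
lm_head = embed_tokens  # tie by reference, as the new constructor does

# Both names resolve to the same underlying storage, so loading the embedding
# weight once also "loads" the output head.
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()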
@@ -376,7 +376,7 @@ class Qwen2MoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -668,7 +668,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -686,8 +686,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
...
@@ -261,7 +261,7 @@ class StableLmForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -396,6 +396,9 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -413,7 +416,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def get_hidden_dim(self, module_name):
@@ -501,14 +504,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-        if (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-        ):
-            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
-            param = self.lm_head.weight
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, self.model.embed_tokens.weight)

         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
...
@@ -315,7 +315,7 @@ class XverseForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(
...
@@ -390,7 +390,7 @@ class XverseMoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -20,6 +20,7 @@ import random
 import tempfile
 from typing import List, Optional

+from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_nvgpu_memory_capacity,
@@ -204,6 +205,12 @@ class ServerArgs:
                 "Overlap schedule is disabled."
             )

+        # GGUF
+        if (
+            self.load_format == "auto" or self.load_format == "gguf"
+        ) and check_gguf_file(self.model_path):
+            self.quantization = self.load_format = "gguf"
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
@@ -243,7 +250,7 @@ class ServerArgs:
             "--load-format",
             type=str,
             default=ServerArgs.load_format,
-            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
+            choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             "and fall back to the pytorch bin format if safetensors format "
@@ -253,7 +260,8 @@ class ServerArgs:
             '"npcache" will load the weights in pytorch format and store '
             "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
-            "which is mainly for profiling.",
+            "which is mainly for profiling."
+            '"gguf" will load the weights in the gguf format. ',
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -293,6 +301,7 @@ class ServerArgs:
                 "gptq_marlin",
                 "awq_marlin",
                 "bitsandbytes",
+                "gguf",
             ],
             help="The quantization method.",
         )
...
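The `check_gguf_file` helper imported above is not part of this diff (it lives in `sglang.srt.hf_transformers_utils`). For orientation only, GGUF files are normally identified by a 4-byte magic header; a rough, hypothetical sketch of such a check (the real helper may differ):

from pathlib import Path

def looks_like_gguf(model_path: str) -> bool:
    # GGUF files begin with the ASCII magic b"GGUF".
    path = Path(model_path)
    if not path.is_file():
        return False
    with path.open("rb") as f:
        return f.read(4) == b"GGUF"

With the ServerArgs change above, pointing --model-path at such a file is enough: load_format and quantization are switched to "gguf" automatically.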
@@ -557,6 +557,29 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
     setattr(GroupCoordinator, "all_gather", all_gather)


+def monkey_patch_vllm_gguf_config():
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gguf import (
+        GGUFConfig,
+        GGUFEmbeddingMethod,
+        GGUFLinearMethod,
+    )
+
+    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+
+    def get_quant_method_with_embedding_replaced(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return GGUFLinearMethod(self)
+        elif isinstance(layer, VocabParallelEmbedding):
+            # patch to own VocabParallelEmbedding
+            return GGUFEmbeddingMethod(self)
+        return None
+
+    setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)
+
+
 def maybe_set_triton_cache_manager() -> None:
     """Set environment variable to tell Triton to use a
     custom cache manager"""
...
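Design note: judging from the inline comment, vLLM's stock `GGUFConfig.get_quant_method` only special-cases vLLM's own embedding class, so the patch above swaps in a version that also returns `GGUFEmbeddingMethod` for sglang's `VocabParallelEmbedding`. That is what lets a GGUF-quantized `embed_tokens` double as the tied `lm_head` now passed to the logits processor in the model changes earlier in this commit.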
@@ -15,6 +15,7 @@ suites = {
         "test_double_sparsity.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",
+        "test_gguf.py",
         "test_input_embeddings.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
...
@@ -16,7 +16,7 @@ from sglang.test.test_utils import (
 from sglang.utils import terminate_process


-class TestGetParameterByName(unittest.TestCase):
+class TestGetWeightsByName(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
...
import unittest

from huggingface_hub import hf_hub_download

import sglang as sgl


class TestGGUF(unittest.TestCase):
    def test_models(self):
        prompt = "Today is a sunny day and I like"
        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        model_path = hf_hub_download(
            "Qwen/Qwen2-1.5B-Instruct-GGUF",
            filename="qwen2-1_5b-instruct-q4_k_m.gguf",
        )

        engine = sgl.Engine(model_path=model_path, random_seed=42)
        outputs = engine.generate(prompt, sampling_params)["text"]
        engine.shutdown()

        self.assertEqual(outputs, " it. I have a lot of work")


if __name__ == "__main__":
    unittest.main()
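The test above exercises the path end to end: it downloads a pre-quantized GGUF checkpoint from the Hub and runs it through the offline engine. A locally stored GGUF file works the same way; the path below is a hypothetical placeholder, and no load-format or quantization arguments are needed because ServerArgs detects the file type via check_gguf_file:

import sglang as sgl

# Hypothetical local path; any valid .gguf checkpoint works the same way.
engine = sgl.Engine(model_path="./qwen2-1_5b-instruct-q4_k_m.gguf", random_seed=42)
print(
    engine.generate(
        "Today is a sunny day and I like",
        {"temperature": 0, "max_new_tokens": 8},
    )["text"]
)
engine.shutdown()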