sglang · Commit 4936be8a (unverified)

Authored Nov 30, 2024 by Lianmin Zheng; committed via GitHub on Nov 30, 2024.
Parent: 1bfa511b

Revert "Revert "[FEAT] Support GGUF format"" (#2287)

This revert-of-a-revert re-applies the original GGUF-support change, so sglang can load checkpoints in the GGUF format used by llama.cpp quantizations.

Showing 20 changed files with 95 additions and 45 deletions (+95 −45).
python/sglang/srt/models/minicpm3.py            +3  −5
python/sglang/srt/models/mixtral.py             +1  −1
python/sglang/srt/models/mixtral_quant.py       +1  −1
python/sglang/srt/models/mllama.py              +1  −1
python/sglang/srt/models/olmo.py                +1  −6
python/sglang/srt/models/olmoe.py               +1  −1
python/sglang/srt/models/phi3_small.py          +5  −2
python/sglang/srt/models/qwen.py                +1  −1
python/sglang/srt/models/qwen2.py               +9  −7
python/sglang/srt/models/qwen2_moe.py           +1  −1
python/sglang/srt/models/qwen2_vl.py            +1  −3
python/sglang/srt/models/stablelm.py            +1  −1
python/sglang/srt/models/torch_native_llama.py  +5  −10
python/sglang/srt/models/xverse.py              +1  −1
python/sglang/srt/models/xverse_moe.py          +1  −1
python/sglang/srt/server_args.py                +11 −2
python/sglang/srt/utils.py                      +23 −0
test/srt/run_suite.py                           +1  −0
test/srt/test_get_weights_by_name.py            +1  −1
test/srt/test_gguf.py (new)                     +26 −0
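The recurring change across the model files below is an interface refactor: the `logits_processor` is now called with the `lm_head` module itself rather than its `.weight` tensor, and models that tie word embeddings alias `lm_head` to `embed_tokens` at construction time. Passing the module matters for GGUF because a GGUF-quantized head does not expose a plain weight tensor to multiply against. The following is a minimal sketch of why the interface change helps; the class name and the `quant_method` attribute are illustrative assumptions, not sglang's actual `LogitsProcessor`:

```python
import torch
import torch.nn as nn


class SketchLogitsProcessor(nn.Module):
    """Illustrative stand-in for sglang's LogitsProcessor (names assumed)."""

    def forward(self, hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
        quant_method = getattr(lm_head, "quant_method", None)
        if quant_method is not None:
            # A GGUF-quantized head has no plain .weight to multiply against;
            # its quant method performs the dequantizing matmul itself.
            return quant_method.apply(lm_head, hidden_states)
        # Unquantized heads still reduce to an ordinary matmul on the weight.
        return torch.matmul(hidden_states, lm_head.weight.t())


# Usage with an ordinary (unquantized) head:
head = nn.Linear(8, 32, bias=False)
logits = SketchLogitsProcessor()(torch.randn(2, 8), head)
print(logits.shape)  # torch.Size([2, 32])
```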
python/sglang/srt/models/minicpm3.py

```diff
@@ -585,12 +585,10 @@ class MiniCPM3ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
python/sglang/srt/models/mixtral.py

```diff
@@ -310,7 +310,7 @@ class MixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
```
python/sglang/srt/models/mixtral_quant.py

```diff
@@ -343,7 +343,7 @@ class QuantMixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
```
python/sglang/srt/models/mllama.py

```diff
@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
             skip_cross_attention=skip_cross_attention,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.language_model.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
```
python/sglang/srt/models/olmo.py

```diff
@@ -306,7 +306,7 @@ class OlmoForCausalLM(nn.Module):
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -326,11 +326,6 @@ class OlmoForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            # With tie_word_embeddings, we can skip lm_head.weight
-            # The weight might appear unnecessarily in the files if the model is
-            # processed with quantization, LoRA, fine-tuning, etc.
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
```
python/sglang/srt/models/olmoe.py

```diff
@@ -321,7 +321,7 @@ class OlmoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
```
python/sglang/srt/models/phi3_small.py

```diff
@@ -397,10 +397,13 @@ class Phi3SmallForCausalLM(nn.Module):
     def compute_logits(
         self,
+        input_ids: torch.LongTensor,
         hidden_states: torch.Tensor,
         sampling_metadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+        logits = self.logits_processor(
+            input_ids, self.lm_head, hidden_states, sampling_metadata
+        )
         if self.dummy_token_indices is not None and logits is not None:
             logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
         return logits
@@ -422,7 +425,7 @@ class Phi3SmallForCausalLM(nn.Module):
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
```
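The first hunk is a knock-on signature change: `compute_logits` gains an `input_ids` parameter so its call can satisfy the logits processor's new convention, which takes `input_ids` first.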
python/sglang/srt/models/qwen.py

```diff
@@ -260,7 +260,7 @@ class QWenLMHeadModel(nn.Module):
     ):
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
```
python/sglang/srt/models/qwen2.py

```diff
@@ -230,6 +230,7 @@ class Qwen2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -276,7 +277,12 @@ class Qwen2ForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -292,7 +298,7 @@ class Qwen2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -306,6 +312,7 @@ class Qwen2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -335,11 +342,6 @@ class Qwen2ForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-            if (
-                self.config.tie_word_embeddings
-                and name == "model.embed_tokens.weight"
-            ):
-                weight_loader(params_dict["lm_head.weight"], loaded_weight)


 EntryClass = Qwen2ForCausalLM
```
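With `tie_word_embeddings`, the constructor now simply aliases `self.lm_head = self.model.embed_tokens`, so both names share one parameter and the `load_weights` special case that copied `model.embed_tokens.weight` into `lm_head.weight` becomes unnecessary. A standalone sketch of that aliasing behavior (not sglang code; plain `nn.Embedding` stands in for the parallel modules):

```python
import torch
import torch.nn as nn

embed_tokens = nn.Embedding(100, 16)
lm_head = embed_tokens  # the aliasing used when tie_word_embeddings is set

with torch.no_grad():
    # Stand-in for a weight loader writing the checkpoint's embedding weight.
    embed_tokens.weight.fill_(0.5)

# The same Parameter object serves both roles, so no explicit copy into
# lm_head.weight is needed any more.
print(lm_head.weight is embed_tokens.weight)  # True
```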
python/sglang/srt/models/qwen2_moe.py

```diff
@@ -376,7 +376,7 @@ class Qwen2MoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
```
python/sglang/srt/models/qwen2_vl.py

```diff
@@ -668,7 +668,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -686,8 +686,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
```
python/sglang/srt/models/stablelm.py

```diff
@@ -261,7 +261,7 @@ class StableLmForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
```
python/sglang/srt/models/torch_native_llama.py

```diff
@@ -396,7 +396,10 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

         # turning off autotune for fp8dq since it doesn't give speedup and
@@ -413,7 +416,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def get_hidden_dim(self, module_name):
@@ -501,14 +504,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

-        if (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-        ):
-            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
-            param = self.lm_head.weight
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, self.model.embed_tokens.weight)

         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
```
python/sglang/srt/models/xverse.py

```diff
@@ -315,7 +315,7 @@ class XverseForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(
```
python/sglang/srt/models/xverse_moe.py

```diff
@@ -390,7 +390,7 @@ class XverseMoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
```
python/sglang/srt/server_args.py

```diff
@@ -20,6 +20,7 @@ import random
 import tempfile
 from typing import List, Optional

+from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_nvgpu_memory_capacity,
@@ -204,6 +205,12 @@ class ServerArgs:
                 "Overlap schedule is disabled."
             )

+        # GGUF
+        if (
+            self.load_format == "auto" or self.load_format == "gguf"
+        ) and check_gguf_file(self.model_path):
+            self.quantization = self.load_format = "gguf"
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
@@ -243,7 +250,7 @@ class ServerArgs:
             "--load-format",
             type=str,
             default=ServerArgs.load_format,
-            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
+            choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             "and fall back to the pytorch bin format if safetensors format "
@@ -253,7 +260,8 @@ class ServerArgs:
             '"npcache" will load the weights in pytorch format and store '
             "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
-            "which is mainly for profiling.",
+            "which is mainly for profiling."
+            '"gguf" will load the weights in the gguf format. ',
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -293,6 +301,7 @@ class ServerArgs:
                 "gptq_marlin",
                 "awq_marlin",
                 "bitsandbytes",
+                "gguf",
             ],
             help="The quantization method.",
         )
```
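With these options in place, a GGUF checkpoint can be requested explicitly via `--load-format gguf` or picked up automatically: when the model path passes `check_gguf_file`, both `load_format` and `quantization` are rewritten to `"gguf"`. A minimal sketch of what that enables, mirroring the new test at the end of this commit (the local path is illustrative):

```python
# Minimal sketch, mirroring test_gguf.py below: handing Engine a .gguf file
# triggers the auto-detection added above, so load_format and quantization
# are both set to "gguf" without extra flags. The path is illustrative.
import sglang as sgl

engine = sgl.Engine(model_path="./qwen2-1_5b-instruct-q4_k_m.gguf")
print(engine.generate("Hello", {"temperature": 0, "max_new_tokens": 8})["text"])
engine.shutdown()
```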
python/sglang/srt/utils.py

```diff
@@ -557,6 +557,29 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
     setattr(GroupCoordinator, "all_gather", all_gather)


+def monkey_patch_vllm_gguf_config():
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gguf import (
+        GGUFConfig,
+        GGUFEmbeddingMethod,
+        GGUFLinearMethod,
+    )
+
+    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+
+    def get_quant_method_with_embedding_replaced(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return GGUFLinearMethod(self)
+        elif isinstance(layer, VocabParallelEmbedding):
+            # patch to own VocabParallelEmbedding
+            return GGUFEmbeddingMethod(self)
+        return None
+
+    setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)
+
+
 def maybe_set_triton_cache_manager() -> None:
     """Set environment variable to tell Triton to use a
     custom cache manager"""
```
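`monkey_patch_vllm_gguf_config` reroutes vLLM's GGUF quantization config so that sglang's own `VocabParallelEmbedding` receives a `GGUFEmbeddingMethod` instead of falling through to vLLM's embedding class. The companion helper `check_gguf_file`, imported in `server_args.py` above, is not part of this diff; below is a plausible sketch of such a check, assuming detection via the `.gguf` extension and the 4-byte `GGUF` magic that the format's files begin with:

```python
# A plausible sketch of GGUF detection, *not* the function from this commit.
from pathlib import Path


def check_gguf_file_sketch(model_path: str) -> bool:
    """Guess whether model_path is a GGUF file (extension or magic bytes)."""
    path = Path(model_path)
    if not path.is_file():
        return False
    if path.suffix == ".gguf":
        return True
    with path.open("rb") as f:
        # Every GGUF file starts with the ASCII magic b"GGUF".
        return f.read(4) == b"GGUF"
```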
test/srt/run_suite.py

```diff
@@ -15,6 +15,7 @@ suites = {
         "test_double_sparsity.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",
+        "test_gguf.py",
         "test_input_embeddings.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
```
test/srt/test_get_parameter_by_name.py → test/srt/test_get_weights_by_name.py (renamed)

```diff
@@ -16,7 +16,7 @@ from sglang.test.test_utils import (
 from sglang.utils import terminate_process


-class TestGetParameterByName(unittest.TestCase):
+class TestGetWeightsByName(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
```
test/srt/test_gguf.py (new file, mode 100644)

```python
import unittest

from huggingface_hub import hf_hub_download

import sglang as sgl


class TestGGUF(unittest.TestCase):
    def test_models(self):
        prompt = "Today is a sunny day and I like"
        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        model_path = hf_hub_download(
            "Qwen/Qwen2-1.5B-Instruct-GGUF",
            filename="qwen2-1_5b-instruct-q4_k_m.gguf",
        )

        engine = sgl.Engine(model_path=model_path, random_seed=42)
        outputs = engine.generate(prompt, sampling_params)["text"]
        engine.shutdown()

        self.assertEqual(outputs, " it. I have a lot of work")


if __name__ == "__main__":
    unittest.main()
```
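Because the file ends with `unittest.main()`, the test can be run directly (e.g. `python3 test/srt/test_gguf.py`). It downloads a q4_k_m GGUF build of Qwen2-1.5B-Instruct from the Hugging Face Hub and checks a single greedy (temperature 0) completion against a fixed string; since no explicit format is passed to `sgl.Engine`, it also relies on the GGUF auto-detection added in `server_args.py`.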