change / sglang / Commits / 4936be8a

Commit 4936be8a (unverified), authored Nov 30, 2024 by Lianmin Zheng, committed by GitHub on Nov 30, 2024

Revert "Revert "[FEAT] Support GGUF format"" (#2287)

parent 1bfa511b

Changes: 41 files in total; showing 20 changed files with 136 additions and 89 deletions (+136 -89)
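This commit restores end-to-end GGUF support: config and tokenizer resolution (hf_transformers_utils.py), weight loading through vLLM's GGUF path (model_runner.py), and logit computation against a possibly quantized LM head (logits_processor.py). As a hedged usage sketch, the flags below are inferred from this diff (the `load_format == "gguf"` check in model_runner.py) and may vary by sglang version; the model path is illustrative:

    python3 -m sglang.launch_server \
        --model-path ./Llama-3.2-1B-Instruct-Q4_K_M.gguf \
        --load-format gguf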
docs/requirements.txt                                  +1   -0
python/sglang/srt/hf_transformers_utils.py             +35  -1
python/sglang/srt/layers/logits_processor.py           +17  -3
python/sglang/srt/layers/vocab_parallel_embedding.py   +1   -0
python/sglang/srt/model_executor/model_runner.py       +3   -0
python/sglang/srt/models/baichuan.py                   +6   -5
python/sglang/srt/models/chatglm.py                    +1   -1
python/sglang/srt/models/commandr.py                   +1   -1
python/sglang/srt/models/dbrx.py                       +1   -1
python/sglang/srt/models/deepseek.py                   +1   -1
python/sglang/srt/models/deepseek_v2.py                +1   -1
python/sglang/srt/models/exaone.py                     +1   -1
python/sglang/srt/models/gemma.py                      +1   -1
python/sglang/srt/models/gemma2.py                     +1   -1
python/sglang/srt/models/gpt2.py                       +1   -1
python/sglang/srt/models/gpt_bigcode.py                +1   -1
python/sglang/srt/models/grok.py                       +1   -1
python/sglang/srt/models/internlm2.py                  +1   -1
python/sglang/srt/models/llama.py                      +58  -63
python/sglang/srt/models/minicpm.py                    +3   -5
docs/requirements.txt

@@ -15,3 +15,4 @@ sphinx-copybutton
 sphinx-tabs
 sphinxcontrib-mermaid
 urllib3<2.0.0
+gguf>=0.10.0
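The docs environment now pulls in the gguf package as well. A throwaway sanity check that an installed copy satisfies the new pin (not part of the repo):

    # Verify the installed gguf package meets the pinned minimum version.
    from importlib.metadata import version
    from packaging.version import Version

    assert Version(version("gguf")) >= Version("0.10.0")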
python/sglang/srt/hf_transformers_utils.py

@@ -16,6 +16,7 @@
 import contextlib
 import os
 import warnings
+from pathlib import Path
 from typing import Dict, Optional, Type, Union

 from huggingface_hub import snapshot_download
@@ -27,6 +28,7 @@ from transformers import (
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

 try:
     from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
@@ -60,15 +62,29 @@ def get_config(
     trust_remote_code: bool,
     revision: Optional[str] = None,
     model_override_args: Optional[dict] = None,
+    **kwargs,
 ):
+    is_gguf = check_gguf_file(model)
+    if is_gguf:
+        kwargs["gguf_file"] = model
+        model = Path(model).parent
+
     config = AutoConfig.from_pretrained(
-        model, trust_remote_code=trust_remote_code, revision=revision
+        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
     if model_override_args:
         config.update(model_override_args)
+
+    # Special architecture mapping check for GGUF models
+    if is_gguf:
+        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
+        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
+        config.update({"architectures": [model_type]})
+
     return config
@@ -123,6 +139,11 @@ def get_tokenizer(
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False

+    is_gguf = check_gguf_file(tokenizer_name)
+    if is_gguf:
+        kwargs["gguf_file"] = tokenizer_name
+        tokenizer_name = Path(tokenizer_name).parent
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
@@ -195,3 +216,16 @@ def attach_additional_stop_token_ids(tokenizer):
         )
     else:
         tokenizer.additional_stop_token_ids = None
+
+
+def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
+    """Check if the file is a GGUF model."""
+    model = Path(model)
+    if not model.is_file():
+        return False
+    elif model.suffix == ".gguf":
+        return True
+
+    with open(model, "rb") as f:
+        header = f.read(4)
+    return header == b"GGUF"
python/sglang/srt/layers/logits_processor.py

@@ -23,6 +23,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )

+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -163,7 +164,7 @@ class LogitsProcessor(nn.Module):
         self,
         input_ids,
         hidden_states,
-        weight,
+        lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ):
         if isinstance(logits_metadata, ForwardBatch):
@@ -178,7 +179,7 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]

-        last_logits = torch.matmul(last_hidden, weight.T)
+        last_logits = self._get_logits(last_hidden, lm_head)
         if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -229,7 +230,7 @@ class LogitsProcessor(nn.Module):
             # Compute the logits and logprobs for all required tokens
             states = torch.cat(states, dim=0)
-            all_logits = torch.matmul(states, weight.T)
+            all_logits = self._get_logits(states, lm_head)
             if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -276,6 +277,19 @@ class LogitsProcessor(nn.Module):
             output_top_logprobs=output_top_logprobs,
         )

+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if hasattr(lm_head, "weight"):
+            logits = torch.matmul(hidden_states, lm_head.weight.T)
+        else:
+            # GGUF models
+            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+        return logits
+

 def test():
     all_logprobs = torch.tensor(
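The key design change above: the logits processor now receives the LM-head module rather than its weight tensor, and `_get_logits` duck-types on the presence of `.weight`. A dense head exposes a materialized weight matrix; a GGUF-quantized head keeps packed tensors and exposes a `linear_method` that dequantizes during the matmul. A minimal sketch of the same dispatch with stand-in classes (the class names below are illustrative, not sglang's or vLLM's):

    import torch

    class DenseHead:
        def __init__(self, vocab, hidden):
            self.weight = torch.randn(vocab, hidden)

    class QuantHead:
        """Stand-in for a GGUF head: no materialized .weight."""
        class _Method:
            def apply(self, layer, x, bias=None):
                return x @ layer.packed.T  # the real method dequantizes here
        def __init__(self, vocab, hidden):
            self.packed = torch.randn(vocab, hidden)
            self.linear_method = self._Method()

    def get_logits(hidden_states, lm_head, bias=None):
        # Same branch structure as _get_logits above.
        if hasattr(lm_head, "weight"):
            return torch.matmul(hidden_states, lm_head.weight.T)
        return lm_head.linear_method.apply(lm_head, hidden_states, bias)

    x = torch.randn(2, 8)
    print(get_logits(x, DenseHead(32, 8)).shape)  # torch.Size([2, 32])
    print(get_logits(x, QuantHead(32, 8)).shape)  # torch.Size([2, 32])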
python/sglang/srt/layers/vocab_parallel_embedding.py

@@ -222,6 +222,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         enable_tp: bool = True,
     ):
         super().__init__()
+        self.quant_config = quant_config
         self.enable_tp = enable_tp
         if self.enable_tp:
python/sglang/srt/model_executor/model_runner.py

@@ -59,6 +59,7 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     is_hip,
+    monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
@@ -297,6 +298,8 @@ class ModelRunner:
             download_dir=self.server_args.download_dir,
         )
         monkey_patch_vllm_model_config()
+        if self.server_args.load_format == "gguf":
+            monkey_patch_vllm_gguf_config()
         self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
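Note that `monkey_patch_vllm_gguf_config()` runs only when GGUF loading is requested, so the default path stays untouched, and it runs before `VllmModelConfig` is constructed so the patched behavior is picked up. A generic sketch of that patch-then-construct pattern (the class and method below are illustrative stand-ins, not vLLM's real internals):

    # Generic monkey-patch pattern: swap a method before the dependent
    # object is used. All names here are illustrative stand-ins.
    class VendorConfig:
        def verify_quantization(self, method: str) -> None:
            if method not in ("awq", "gptq"):
                raise ValueError(f"unknown quantization: {method}")

    def patch_for_gguf() -> None:
        original = VendorConfig.verify_quantization

        def patched(self, method: str) -> None:
            if method == "gguf":  # newly accepted format
                return
            original(self, method)

        VendorConfig.verify_quantization = patched

    patch_for_gguf()                             # patch first...
    VendorConfig().verify_quantization("gguf")   # ...then use: no error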
python/sglang/srt/models/baichuan.py

@@ -338,11 +338,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = BaiChuanModel(config, position_embedding, quant_config)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
-        )
         if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)

     def forward(
@@ -353,7 +354,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
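This baichuan.py change is the template for the per-model edits that follow: when `tie_word_embeddings` is set, the model now reuses the embedding module itself as the LM head instead of building a separate `ParallelLMHead` and aliasing its weight tensor, and `forward` passes the module (not `.weight`) to the logits processor. A minimal sketch of the two tying styles with plain torch modules (not sglang's parallel layers); tensor-level tying assumes a materialized `.weight`, which a GGUF-quantized head may not have:

    import torch.nn as nn

    embed = nn.Embedding(100, 16)

    # Old pattern: a separate head whose weight tensor aliases the embedding.
    head_old = nn.Linear(16, 100, bias=False)
    head_old.weight = embed.weight   # only works when a .weight exists

    # New pattern: reuse the embedding module itself as the head, so one
    # module owns the (possibly quantized) parameters.
    head_new = embed
    assert head_new.weight.data_ptr() == embed.weight.data_ptr()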
python/sglang/srt/models/chatglm.py

@@ -378,7 +378,7 @@ class ChatGLMForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/commandr.py

@@ -339,7 +339,7 @@ class CohereForCausalLM(nn.Module):
             forward_batch,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/dbrx.py

@@ -390,7 +390,7 @@ class DbrxForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/deepseek.py

@@ -394,7 +394,7 @@ class DeepseekForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/deepseek_v2.py

@@ -763,7 +763,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch)
         if not forward_batch.forward_mode.is_idle():
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/exaone.py

@@ -314,7 +314,7 @@ class ExaoneForCausalLM(nn.Module):
             input_ids, positions, forward_batch, input_embeds
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/gemma.py

@@ -298,7 +298,7 @@ class GemmaForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/gemma2.py

@@ -363,7 +363,7 @@ class Gemma2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

     def get_attention_sliding_window_size(self):
python/sglang/srt/models/gpt2.py

@@ -247,7 +247,7 @@ class GPT2LMHeadModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/gpt_bigcode.py

@@ -271,7 +271,7 @@ class GPTBigCodeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/grok.py

@@ -304,7 +304,7 @@ class Grok1ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/internlm2.py

@@ -270,7 +270,7 @@ class InternLM2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.output.weight, forward_batch
+            input_ids, hidden_states, self.output, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/llama.py

@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import make_layers
+from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -258,6 +259,7 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -305,7 +307,12 @@ class LlamaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.stacked_params_mapping = [
@@ -329,7 +336,7 @@ class LlamaForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -373,7 +380,6 @@ class LlamaForCausalLM(nn.Module):
         return len(params_dict)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        embed_tokens_weight = None
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -385,12 +391,6 @@ class LlamaForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
-        load_tie_word_embeddings = (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-            and "lm_head.weight" in params_dict
-        )

         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
@@ -423,16 +423,6 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

-            if load_tie_word_embeddings and name == "model.embed_tokens.weight":
-                embed_tokens_weight = loaded_weight
-
-        if load_tie_word_embeddings:
-            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
-            param = self.lm_head.weight
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            if embed_tokens_weight is not None:
-                weight_loader(param, embed_tokens_weight)
-
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))

     def get_weights_by_name(
@@ -444,6 +434,17 @@ class LlamaForCausalLM(nn.Module):
         For optimized performance, please use torch.save and torch.load.
         """
         try:
+            if name == "lm_head.weight" and self.config.tie_word_embeddings:
+                logger.info(
+                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
+                )
+                return (
+                    self.model.embed_tokens.weight.cpu()
+                    .to(torch.float32)
+                    .numpy()
+                    .tolist()[:truncate_size]
+                )
+
             mapped_name = name
             mapped_shard_id = None
             for param_name, weight_name, shard_id in self.stacked_params_mapping:
@@ -452,54 +453,48 @@ class LlamaForCausalLM(nn.Module):
                     mapped_shard_id = shard_id
                     break
             params_dict = dict(self.named_parameters())
-            if mapped_name in params_dict:
-                param = params_dict[mapped_name]
-                if mapped_shard_id is not None:
-                    if mapped_shard_id in ["q", "k", "v"]:
-                        num_heads = self.config.num_attention_heads // tp_size
-                        num_kv_heads = self.config.num_key_value_heads // tp_size
-                        head_dim = (
-                            self.config.hidden_size // self.config.num_attention_heads
-                        )
-                        if mapped_shard_id == "q":
-                            offset = 0
-                            size = num_heads * head_dim
-                        elif mapped_shard_id == "k":
-                            offset = num_heads * head_dim
-                            size = num_kv_heads * head_dim
-                        elif mapped_shard_id == "v":
-                            offset = (num_heads + num_kv_heads) * head_dim
-                            size = num_kv_heads * head_dim
-                        weight = param.data.narrow(0, offset, size)
-                    elif mapped_shard_id in [0, 1]:
-                        intermediate_size = self.config.intermediate_size
-                        slice_size = intermediate_size // tp_size
-                        hidden_size = self.config.hidden_size
-                        if mapped_shard_id == 0:  # gate_proj
-                            offset = 0
-                            size = slice_size
-                        elif mapped_shard_id == 1:  # up_proj
-                            offset = slice_size
-                            size = slice_size
-                        weight = param.data.narrow(0, offset, size)
-                    else:
-                        weight = param.data
-                else:
-                    weight = param.data
-                if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
-                    gathered_weights = [
-                        torch.zeros_like(weight) for _ in range(tp_size)
-                    ]
-                    torch.distributed.all_gather(gathered_weights, weight)
-                    weight = torch.cat(gathered_weights, dim=1)
-                return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
-            else:
-                return None
-        except Exception as e:
+            param = params_dict[mapped_name]
+            if mapped_shard_id is not None:
+                if mapped_shard_id in ["q", "k", "v"]:
+                    num_heads = self.config.num_attention_heads // tp_size
+                    num_kv_heads = self.config.num_key_value_heads // tp_size
+                    head_dim = (
+                        self.config.hidden_size // self.config.num_attention_heads
+                    )
+                    if mapped_shard_id == "q":
+                        offset = 0
+                        size = num_heads * head_dim
+                    elif mapped_shard_id == "k":
+                        offset = num_heads * head_dim
+                        size = num_kv_heads * head_dim
+                    elif mapped_shard_id == "v":
+                        offset = (num_heads + num_kv_heads) * head_dim
+                        size = num_kv_heads * head_dim
+                    weight = param.data.narrow(0, offset, size)
+                elif mapped_shard_id in [0, 1]:
+                    intermediate_size = self.config.intermediate_size
+                    slice_size = intermediate_size // tp_size
+                    if mapped_shard_id == 0:  # gate_proj
+                        offset = 0
+                        size = slice_size
+                    elif mapped_shard_id == 1:  # up_proj
+                        offset = slice_size
+                        size = slice_size
+                    weight = param.data.narrow(0, offset, size)
+                else:
+                    weight = param.data
+            else:
+                weight = param.data
+            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
+                torch.distributed.all_gather(gathered_weights, weight)
+                weight = torch.cat(gathered_weights, dim=1)
+            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+        except Exception:
             logger.error(
-                f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
+                f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
             )
             return None
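For reference, `get_weights_by_name` carves per-projection weights out of the fused `qkv_proj` parameter by narrowing along dim 0 with the offsets computed above. A worked example of that arithmetic with illustrative Llama-8B-style sizes (assumed numbers, not read from any checkpoint):

    # Worked example of the qkv_proj offset math in get_weights_by_name.
    num_attention_heads, num_key_value_heads, hidden_size, tp_size = 32, 8, 4096, 1

    num_heads = num_attention_heads // tp_size
    num_kv_heads = num_key_value_heads // tp_size
    head_dim = hidden_size // num_attention_heads  # 128

    spans = {
        "q": (0, num_heads * head_dim),                        # rows [0, 4096)
        "k": (num_heads * head_dim, num_kv_heads * head_dim),  # rows [4096, 5120)
        "v": ((num_heads + num_kv_heads) * head_dim,
              num_kv_heads * head_dim),                        # rows [5120, 6144)
    }
    for shard, (offset, size) in spans.items():
        print(shard, offset, size)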
python/sglang/srt/models/minicpm.py

@@ -308,12 +308,10 @@ class MiniCPMForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [