sglang · Commits · 993956c6

Add support for IBM Granite 3.x models (#2437)

Unverified commit 993956c6, authored Dec 11, 2024 by Fred Reiss, committed by GitHub on Dec 11, 2024.
Parent: f8548295

Showing 5 changed files with 562 additions and 1 deletion:
- docs/references/supported_models.md (+1, -0)
- python/sglang/lang/chat_template.py (+32, -0)
- python/sglang/srt/layers/logits_processor.py (+11, -1)
- python/sglang/srt/models/granite.py (+517, -0)
- test/srt/models/test_generation_models.py (+1, -0)
docs/references/supported_models.md (+1, -0)

```diff
@@ -29,6 +29,7 @@
 - SmolLM
 - GLM-4
 - Phi-3-Small
+- IBM Granite 3
 
 ## Embedding Models
 
```
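With Granite 3 listed among the supported generative models, a served checkpoint can be queried like any other model through SGLang's OpenAI-compatible API. The snippet below is a minimal illustration, not part of this commit; it assumes a server is already running locally (for example via `sglang.launch_server` with `--model-path ibm-granite/granite-3.0-2b-instruct` on port 30000), and the prompt is made up.

```python
# Hypothetical client-side check that a served Granite 3.x model responds.
# Assumes an SGLang server is already running on http://127.0.0.1:30000
# with ibm-granite/granite-3.0-2b-instruct (not part of this commit).
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="ibm-granite/granite-3.0-2b-instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "List three uses of a language model."},
    ],
    max_tokens=64,
    temperature=0,
)
print(response.choices[0].message.content)
```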
python/sglang/lang/chat_template.py (+32, -0)

```diff
@@ -320,6 +320,28 @@ register_chat_template(
     )
 )
 
+register_chat_template(
+    ChatTemplate(
+        name="granite-3-instruct",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "<|start_of_role|>system<|end_of_role|>",
+                "<|end_of_text|>",
+            ),
+            "user": (
+                "<|start_of_role|>user<|end_of_role|>",
+                "<|end_of_text|>",
+            ),
+            "assistant": (
+                "<|start_of_role|>assistant<|end_of_role|>",
+                "<|end_of_text|>",
+            ),
+        },
+        stop_str=("<|end_of_text|>",),
+    )
+)
+
 
 @register_chat_template_matching_function
 def match_dbrx(model_path: str):
```

```diff
@@ -402,6 +424,16 @@ def match_c4ai_command_r(model_path: str):
     return get_chat_template("c4ai-command-r")
 
 
+@register_chat_template_matching_function
+def match_granite_instruct(model_path: str):
+    model_path = model_path.lower()
+    # When future versions of Granite are released, this code may
+    # need to be updated. For now, assume that the Granite 3.0
+    # template works across the board.
+    if "granite" in model_path and "instruct" in model_path:
+        return get_chat_template("granite-3-instruct")
+
+
 if __name__ == "__main__":
     messages = [
         {"role": "system", "content": None},  # None means default
```
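For reference, the "granite-3-instruct" template registered above wraps each turn in the Granite 3.x role markers and terminates it with `<|end_of_text|>`. The sketch below (not part of this commit) builds a prompt string directly from the registered `role_prefix_and_suffix` entries to show the resulting layout; it assumes those fields are accessible as attributes on the object returned by `get_chat_template`, and the example conversation is made up.

```python
# Illustrative only: expand a short conversation with the template registered
# above by concatenating its role prefixes and suffixes.
from sglang.lang.chat_template import get_chat_template

template = get_chat_template("granite-3-instruct")

# Hypothetical conversation used purely for illustration.
conversation = [
    ("system", "You are a helpful assistant."),
    ("user", "What does SGLang do?"),
]

prompt = ""
for role, content in conversation:
    prefix, suffix = template.role_prefix_and_suffix[role]
    prompt += prefix + content + suffix

# Leave the assistant turn open so the model generates the reply.
prompt += template.role_prefix_and_suffix["assistant"][0]
print(prompt)
```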
python/sglang/srt/layers/logits_processor.py (+11, -1)

```diff
@@ -91,9 +91,12 @@ class LogitsMetadata:
 
 
 class LogitsProcessor(nn.Module):
-    def __init__(self, config, skip_all_gather: bool = False):
+    def __init__(
+        self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None
+    ):
         super().__init__()
         self.config = config
+        self.logit_scale = logit_scale
         self.do_tensor_parallel_all_gather = (
             not skip_all_gather and get_tensor_model_parallel_world_size() > 1
         )
```

```diff
@@ -240,6 +243,9 @@ class LogitsProcessor(nn.Module):
             all_logits = self._get_logits(states, lm_head)
             if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
+
+            # The LM head's weights may be zero-padded for parallelism. Remove any
+            # extra logits that this padding may have produced.
             all_logits = all_logits[:, : self.config.vocab_size].float()
 
             if hasattr(self.config, "final_logit_softcapping"):
```

```diff
@@ -302,6 +308,10 @@ class LogitsProcessor(nn.Module):
         else:
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
 
+        # Optional scaling factor, backported from vLLM 0.4
+        if self.logit_scale is not None:
+            logits.mul_(self.logit_scale)  # In-place multiply
+
         return logits
 
```
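Granite checkpoints store `logits_scaling` as a divisor, while the new `logit_scale` argument is applied multiplicatively, which is why granite.py (below) passes `1.0 / config.logits_scaling`. A standalone sanity check of that equivalence, with a made-up scaling value:

```python
# Standalone sanity check (not part of this commit): dividing logits by
# `logits_scaling` equals multiplying by `logit_scale = 1.0 / logits_scaling`,
# which is how the in-place `logits.mul_(self.logit_scale)` above is used.
import torch

logits_scaling = 8.0          # divisor as stored in a Granite config (example value)
logit_scale = 1.0 / logits_scaling

logits = torch.randn(2, 16)   # toy (batch, vocab) logits
scaled_by_division = logits / logits_scaling

scaled_in_place = logits.clone()
scaled_in_place.mul_(logit_scale)  # mirrors LogitsProcessor's in-place multiply

assert torch.allclose(scaled_by_division, scaled_in_place)
```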
python/sglang/srt/models/granite.py (new file, mode 100644, +517)

```python
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only Granite model compatible with HuggingFace weights."""

import logging
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import GraniteConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


class GraniteMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class GraniteAttention(nn.Module):
    def __init__(
        self,
        config: GraniteConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(
            config, "head_dim", self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = config.attention_multiplier
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=rope_is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class GraniteDecoderLayer(nn.Module):
    def __init__(
        self,
        config: GraniteConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.residual_multiplier = config.residual_multiplier
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = GraniteAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            rope_is_neox_style=rope_is_neox_style,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = GraniteMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = (
            self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )
            * self.residual_multiplier
        )  # multiplier for Maximal Update Parameterization

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states) * self.residual_multiplier
        return hidden_states, residual


class GraniteModel(nn.Module):
    def __init__(
        self,
        config: GraniteConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.layers = nn.ModuleList(
            [
                GraniteDecoderLayer(
                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        residual = None

        hidden_states *= self.config.embedding_multiplier

        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class GraniteForCausalLM(nn.Module):
    def __init__(
        self,
        config: GraniteConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = GraniteModel(config, quant_config=quant_config)
        # If tie_word_embeddings == True, then input and output embeddings are
        # the same tensor. Enforce during object creation so that weights will
        # load correctly even if the LM head weights don't have a separate entry
        # in the state dict.
        self.lm_head = ParallelLMHead(
            config.vocab_size, config.hidden_size, quant_config=quant_config
        )
        if self.config.tie_word_embeddings:
            self.lm_head.tie_weights(self.model.embed_tokens)

        # Granite logit scaling factors are applied via division, but
        # LogitsProcessor expects a multiplicative factor.
        if hasattr(config, "logits_scaling"):
            logit_scale = 1.0 / config.logits_scaling
        else:
            logit_scale = None
        self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        if not get_embedding:
            logits_processor_output: LogitsProcessorOutput = self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
            return logits_processor_output
        else:
            return self.pooler(hidden_states, forward_batch)

    def get_hidden_dim(self, module_name):
        # return input_dim, output_dim
        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
            return self.config.hidden_size, self.config.hidden_size
        elif module_name in ["kv_proj"]:
            return self.config.hidden_size, self.config.hidden_size // (
                self.config.num_attention_heads // self.config.num_key_value_heads
            )
        elif module_name == "gate_up_proj":
            return self.config.hidden_size, self.config.intermediate_size
        elif module_name == "down_proj":
            return self.config.intermediate_size, self.config.hidden_size
        else:
            raise NotImplementedError()

    def get_module_name(self, name):
        params_mapping = {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }
        return params_mapping.get(name, name)

    def get_module_name_from_weight_name(self, name):
        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
            if weight_name in name:
                return (
                    name.replace(weight_name, param_name)[: -len(".weight")],
                    num_shard,
                )
        return name[: -len(".weight")], 1

    def get_num_params(self):
        params_dict = dict(self.named_parameters())
        return len(params_dict)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                # Input and output embeddings are tied, so the output embeddings
                # may not be present in the checkpoint. We assume that the input
                # embeddings are always present in the checkpoint.
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # This block only runs if the preceding for loop doesn't find
                # a match for `name` in `stacked_params_mapping`.

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading kv_scale from ckpts towards new design.
                if name.endswith(".kv_scale") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100, tp_size: int = 1
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        try:
            if name == "lm_head.weight" and self.config.tie_word_embeddings:
                logger.info(
                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
                )
                return (
                    self.model.embed_tokens.weight.cpu()
                    .to(torch.float32)
                    .numpy()
                    .tolist()[:truncate_size]
                )

            mapped_name = name
            mapped_shard_id = None
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name in name:
                    mapped_name = name.replace(weight_name, param_name)
                    mapped_shard_id = shard_id
                    break
            params_dict = dict(self.named_parameters())
            param = params_dict[mapped_name]
            if mapped_shard_id is not None:
                if mapped_shard_id in ["q", "k", "v"]:
                    num_heads = self.config.num_attention_heads // tp_size
                    num_kv_heads = self.config.num_key_value_heads // tp_size
                    head_dim = (
                        self.config.hidden_size // self.config.num_attention_heads
                    )
                    if mapped_shard_id == "q":
                        offset = 0
                        size = num_heads * head_dim
                    elif mapped_shard_id == "k":
                        offset = num_heads * head_dim
                        size = num_kv_heads * head_dim
                    elif mapped_shard_id == "v":
                        offset = (num_heads + num_kv_heads) * head_dim
                        size = num_kv_heads * head_dim
                    weight = param.data.narrow(0, offset, size)
                elif mapped_shard_id in [0, 1]:
                    intermediate_size = self.config.intermediate_size
                    slice_size = intermediate_size // tp_size
                    if mapped_shard_id == 0:  # gate_proj
                        offset = 0
                        size = slice_size
                    elif mapped_shard_id == 1:  # up_proj
                        offset = slice_size
                        size = slice_size
                    weight = param.data.narrow(0, offset, size)
                else:
                    weight = param.data
            else:
                weight = param.data
            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
                torch.distributed.all_gather(gathered_weights, weight)
                weight = torch.cat(gathered_weights, dim=1)
            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]

        except Exception:
            logger.error(
                f"Error getting weights by name {name} in GraniteForCausalLM: "
                f"{get_exception_traceback()}"
            )
            return None


EntryClass = [GraniteForCausalLM]
```
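Relative to the Llama code this file was adapted from, the Granite-specific pieces are the fixed muP-style scaling factors read from the config: `embedding_multiplier` applied to the token embeddings, `attention_multiplier` used as the attention score scale, `residual_multiplier` applied to both residual branches of every decoder layer, and `logits_scaling` applied (as a divisor) to the final logits. The toy forward pass below is a conceptual sketch of where each factor enters; all shapes and values are placeholders, not real Granite config values.

```python
# Conceptual sketch only (not part of this commit): where Granite's scaling
# factors enter a forward pass. The real implementation lives in the classes above.
import torch

# Placeholder values, not taken from an actual Granite 3.x config.
embedding_multiplier = 12.0
attention_multiplier = 0.0078125
residual_multiplier = 0.22
logits_scaling = 8.0

hidden = torch.randn(1, 4, 8) * embedding_multiplier           # GraniteModel.forward

for _ in range(2):                                             # per GraniteDecoderLayer
    q = k = v = torch.randn(1, 4, 8)
    scores = (q @ k.transpose(-1, -2)) * attention_multiplier  # used instead of 1/sqrt(head_dim)
    attn_out = torch.softmax(scores, dim=-1) @ v
    hidden = hidden + attn_out * residual_multiplier           # attention residual branch
    mlp_out = torch.randn(1, 4, 8)                             # stands in for GraniteMLP
    hidden = hidden + mlp_out * residual_multiplier            # MLP residual branch

logits = torch.randn(1, 4, 32)                                 # stands in for the LM head
logits = logits * (1.0 / logits_scaling)                       # LogitsProcessor(logit_scale=...)
```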
test/srt/models/test_generation_models.py (+1, -0)

```diff
@@ -57,6 +57,7 @@ ALL_OTHER_MODELS = [
     ModelCase("openai-community/gpt2"),
     ModelCase("microsoft/Phi-3-small-8k-instruct"),
     ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
+    ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
 ]
 
 TORCH_DTYPES = [torch.float16]
```
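The new `ModelCase` entry feeds SGLang's generic comparison test, which checks SGLang's outputs for the listed checkpoint against a Hugging Face reference. The snippet below is a rough standalone approximation of the reference side only, not the actual test harness; the prompt and decoding settings are assumptions.

```python
# Rough reference-side approximation (not the actual SGLang test harness):
# generate greedily with Hugging Face transformers so the output can be
# compared against SGLang's output for the same checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "ibm-granite/granite-3.0-2b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map="auto"
)

messages = [{"role": "user", "content": "Write a haiku about GPUs."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output = model.generate(input_ids, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```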