Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
993956c6
"sgl-router/vscode:/vscode.git/clone" did not exist on "e7387035476eb2c57fd49608066abf2e5f7551ac"
Unverified
Commit
993956c6
authored
Dec 11, 2024
by
Fred Reiss
Committed by
GitHub
Dec 11, 2024
Browse files
Add support for IBM Granite 3.x models (#2437)
parent
f8548295
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
562 additions
and
1 deletion
+562
-1
docs/references/supported_models.md
docs/references/supported_models.md
+1
-0
python/sglang/lang/chat_template.py
python/sglang/lang/chat_template.py
+32
-0
python/sglang/srt/layers/logits_processor.py
python/sglang/srt/layers/logits_processor.py
+11
-1
python/sglang/srt/models/granite.py
python/sglang/srt/models/granite.py
+517
-0
test/srt/models/test_generation_models.py
test/srt/models/test_generation_models.py
+1
-0
No files found.
docs/references/supported_models.md
View file @
993956c6
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
-
SmolLM
-
SmolLM
-
GLM-4
-
GLM-4
-
Phi-3-Small
-
Phi-3-Small
-
IBM Granite 3
## Embedding Models
## Embedding Models
...
...
python/sglang/lang/chat_template.py
View file @
993956c6
...
@@ -320,6 +320,28 @@ register_chat_template(
...
@@ -320,6 +320,28 @@ register_chat_template(
)
)
)
)
register_chat_template
(
ChatTemplate
(
name
=
"granite-3-instruct"
,
default_system_prompt
=
None
,
role_prefix_and_suffix
=
{
"system"
:
(
"<|start_of_role|>system<|end_of_role|>"
,
"<|end_of_text|>"
,
),
"user"
:
(
"<|start_of_role|>user<|end_of_role|>"
,
"<|end_of_text|>"
,
),
"assistant"
:
(
"<|start_of_role|>assistant<|end_of_role|>"
,
"<|end_of_text|>"
,
),
},
stop_str
=
(
"<|end_of_text|>"
,),
)
)
@
register_chat_template_matching_function
@
register_chat_template_matching_function
def
match_dbrx
(
model_path
:
str
):
def
match_dbrx
(
model_path
:
str
):
...
@@ -402,6 +424,16 @@ def match_c4ai_command_r(model_path: str):
...
@@ -402,6 +424,16 @@ def match_c4ai_command_r(model_path: str):
return
get_chat_template
(
"c4ai-command-r"
)
return
get_chat_template
(
"c4ai-command-r"
)
@
register_chat_template_matching_function
def
match_granite_instruct
(
model_path
:
str
):
model_path
=
model_path
.
lower
()
# When future versions of Granite are released, this code may
# need to be updated. For now, assume that the Granite 3.0
# template works across the board.
if
"granite"
in
model_path
and
"instruct"
in
model_path
:
return
get_chat_template
(
"granite-3-instruct"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
messages
=
[
messages
=
[
{
"role"
:
"system"
,
"content"
:
None
},
# None means default
{
"role"
:
"system"
,
"content"
:
None
},
# None means default
...
...
python/sglang/srt/layers/logits_processor.py
View file @
993956c6
...
@@ -91,9 +91,12 @@ class LogitsMetadata:
...
@@ -91,9 +91,12 @@ class LogitsMetadata:
class
LogitsProcessor
(
nn
.
Module
):
class
LogitsProcessor
(
nn
.
Module
):
def
__init__
(
self
,
config
,
skip_all_gather
:
bool
=
False
):
def
__init__
(
self
,
config
,
skip_all_gather
:
bool
=
False
,
logit_scale
:
Optional
[
float
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
logit_scale
=
logit_scale
self
.
do_tensor_parallel_all_gather
=
(
self
.
do_tensor_parallel_all_gather
=
(
not
skip_all_gather
and
get_tensor_model_parallel_world_size
()
>
1
not
skip_all_gather
and
get_tensor_model_parallel_world_size
()
>
1
)
)
...
@@ -240,6 +243,9 @@ class LogitsProcessor(nn.Module):
...
@@ -240,6 +243,9 @@ class LogitsProcessor(nn.Module):
all_logits
=
self
.
_get_logits
(
states
,
lm_head
)
all_logits
=
self
.
_get_logits
(
states
,
lm_head
)
if
self
.
do_tensor_parallel_all_gather
:
if
self
.
do_tensor_parallel_all_gather
:
all_logits
=
tensor_model_parallel_all_gather
(
all_logits
)
all_logits
=
tensor_model_parallel_all_gather
(
all_logits
)
# The LM head's weights may be zero-padded for parallelism. Remove any
# extra logits that this padding may have produced.
all_logits
=
all_logits
[:,
:
self
.
config
.
vocab_size
].
float
()
all_logits
=
all_logits
[:,
:
self
.
config
.
vocab_size
].
float
()
if
hasattr
(
self
.
config
,
"final_logit_softcapping"
):
if
hasattr
(
self
.
config
,
"final_logit_softcapping"
):
...
@@ -302,6 +308,10 @@ class LogitsProcessor(nn.Module):
...
@@ -302,6 +308,10 @@ class LogitsProcessor(nn.Module):
else
:
else
:
# GGUF models
# GGUF models
logits
=
lm_head
.
linear_method
.
apply
(
lm_head
,
hidden_states
,
embedding_bias
)
logits
=
lm_head
.
linear_method
.
apply
(
lm_head
,
hidden_states
,
embedding_bias
)
# Optional scaling factor, backported from vLLM 0.4
if
self
.
logit_scale
is
not
None
:
logits
.
mul_
(
self
.
logit_scale
)
# In-place multiply
return
logits
return
logits
...
...
python/sglang/srt/models/granite.py
0 → 100644
View file @
993956c6
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only Granite model compatible with HuggingFace weights."""
import
logging
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
GraniteConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
sglang.srt.layers.activation
import
SiluAndMul
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
,
LogitsProcessorOutput
from
sglang.srt.layers.pooler
import
Pooler
,
PoolingType
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.utils
import
get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
class
GraniteMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
GraniteAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GraniteConfig
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
layer_id
:
int
=
0
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_is_neox_style
:
bool
=
True
,
max_position_embeddings
:
int
=
8192
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self
.
head_dim
=
getattr
(
config
,
"head_dim"
,
self
.
hidden_size
//
self
.
total_num_heads
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
config
.
attention_multiplier
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
is_neox_style
=
rope_is_neox_style
,
)
self
.
attn
=
RadixAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
layer_id
=
layer_id
,
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
forward_batch
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
GraniteDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GraniteConfig
,
layer_id
:
int
=
0
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
residual_multiplier
=
config
.
residual_multiplier
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
if
rope_scaling
is
not
None
and
getattr
(
config
,
"original_max_position_embeddings"
,
None
):
rope_scaling
[
"original_max_position_embeddings"
]
=
(
config
.
original_max_position_embeddings
)
rope_is_neox_style
=
getattr
(
config
,
"rope_is_neox_style"
,
True
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
self
.
self_attn
=
GraniteAttention
(
config
=
config
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
layer_id
=
layer_id
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_is_neox_style
=
rope_is_neox_style
,
max_position_embeddings
=
max_position_embeddings
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
mlp
=
GraniteMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
(
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
forward_batch
=
forward_batch
,
)
*
self
.
residual_multiplier
)
# multiplier for Maximal Update Parameterization
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
*
self
.
residual_multiplier
return
hidden_states
,
residual
class
GraniteModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GraniteConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
(
[
GraniteDecoderLayer
(
config
,
i
,
quant_config
=
quant_config
,
prefix
=
f
"model.layers.
{
i
}
"
)
for
i
in
range
(
config
.
num_hidden_layers
)
]
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
if
input_embeds
is
None
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
else
:
hidden_states
=
input_embeds
residual
=
None
hidden_states
*=
self
.
config
.
embedding_multiplier
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
forward_batch
,
residual
,
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
GraniteForCausalLM
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GraniteConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
GraniteModel
(
config
,
quant_config
=
quant_config
)
# If tie_word_embeddings == True, then input and output embeddings are
# the same tensor. Enforce during object creation so that weights will
# load correctly even if the LM head weights don't have a separate entry
# in the state dict.
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
tie_weights
(
self
.
model
.
embed_tokens
)
# Granite logit scaling factors are applied via division, but
# LogitsProcessor expects a multiplicative factor.
if
hasattr
(
config
,
"logits_scaling"
):
logit_scale
=
1.0
/
config
.
logits_scaling
else
:
logit_scale
=
None
self
.
logits_processor
=
LogitsProcessor
(
config
,
logit_scale
=
logit_scale
)
self
.
pooler
=
Pooler
(
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
)
self
.
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
@
torch
.
no_grad
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
input_embeds
:
torch
.
Tensor
=
None
,
get_embedding
:
bool
=
False
,
)
->
LogitsProcessorOutput
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
forward_batch
,
input_embeds
)
if
not
get_embedding
:
logits_processor_output
:
LogitsProcessorOutput
=
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
,
forward_batch
)
return
logits_processor_output
else
:
return
self
.
pooler
(
hidden_states
,
forward_batch
)
def
get_hidden_dim
(
self
,
module_name
):
# return input_dim, output_dim
if
module_name
in
[
"q_proj"
,
"o_proj"
,
"qkv_proj"
]:
return
self
.
config
.
hidden_size
,
self
.
config
.
hidden_size
elif
module_name
in
[
"kv_proj"
]:
return
self
.
config
.
hidden_size
,
self
.
config
.
hidden_size
//
(
self
.
config
.
num_attention_heads
//
self
.
config
.
num_key_value_heads
)
elif
module_name
==
"gate_up_proj"
:
return
self
.
config
.
hidden_size
,
self
.
config
.
intermediate_size
elif
module_name
==
"down_proj"
:
return
self
.
config
.
intermediate_size
,
self
.
config
.
hidden_size
else
:
raise
NotImplementedError
()
def
get_module_name
(
self
,
name
):
params_mapping
=
{
"q_proj"
:
"qkv_proj"
,
"k_proj"
:
"qkv_proj"
,
"v_proj"
:
"qkv_proj"
,
"gate_proj"
:
"gate_up_proj"
,
"up_proj"
:
"gate_up_proj"
,
}
return
params_mapping
.
get
(
name
,
name
)
def
get_module_name_from_weight_name
(
self
,
name
):
for
param_name
,
weight_name
,
shard_id
,
num_shard
in
self
.
stacked_params_mapping
:
if
weight_name
in
name
:
return
(
name
.
replace
(
weight_name
,
param_name
)[:
-
len
(
".weight"
)],
num_shard
,
)
return
name
[:
-
len
(
".weight"
)],
1
def
get_num_params
(
self
):
params_dict
=
dict
(
self
.
named_parameters
())
return
len
(
params_dict
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
or
"projector"
in
name
:
continue
if
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if
name
.
startswith
(
"model.vision_tower"
)
and
name
not
in
params_dict
:
continue
if
"lm_head.weight"
in
name
and
self
.
config
.
tie_word_embeddings
:
# Input and output embeddings are tied, so the output embeddings
# may not be present in the checkpoint. We assume that the input
# embeddings are always present in the checkpoint.
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# This block only runs if the preceding for loop doesn't find
# a match for `name` in `stacked_params_mapping`.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip loading kv_scale from ckpts towards new design.
if
name
.
endswith
(
".kv_scale"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
def
get_weights_by_name
(
self
,
name
:
str
,
truncate_size
:
int
=
100
,
tp_size
:
int
=
1
)
->
Optional
[
torch
.
Tensor
]:
"""Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
Only used for unit test with an unoptimized performance.
For optimized performance, please use torch.save and torch.load.
"""
try
:
if
name
==
"lm_head.weight"
and
self
.
config
.
tie_word_embeddings
:
logger
.
info
(
"word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
)
return
(
self
.
model
.
embed_tokens
.
weight
.
cpu
()
.
to
(
torch
.
float32
)
.
numpy
()
.
tolist
()[:
truncate_size
]
)
mapped_name
=
name
mapped_shard_id
=
None
for
param_name
,
weight_name
,
shard_id
in
self
.
stacked_params_mapping
:
if
weight_name
in
name
:
mapped_name
=
name
.
replace
(
weight_name
,
param_name
)
mapped_shard_id
=
shard_id
break
params_dict
=
dict
(
self
.
named_parameters
())
param
=
params_dict
[
mapped_name
]
if
mapped_shard_id
is
not
None
:
if
mapped_shard_id
in
[
"q"
,
"k"
,
"v"
]:
num_heads
=
self
.
config
.
num_attention_heads
//
tp_size
num_kv_heads
=
self
.
config
.
num_key_value_heads
//
tp_size
head_dim
=
(
self
.
config
.
hidden_size
//
self
.
config
.
num_attention_heads
)
if
mapped_shard_id
==
"q"
:
offset
=
0
size
=
num_heads
*
head_dim
elif
mapped_shard_id
==
"k"
:
offset
=
num_heads
*
head_dim
size
=
num_kv_heads
*
head_dim
elif
mapped_shard_id
==
"v"
:
offset
=
(
num_heads
+
num_kv_heads
)
*
head_dim
size
=
num_kv_heads
*
head_dim
weight
=
param
.
data
.
narrow
(
0
,
offset
,
size
)
elif
mapped_shard_id
in
[
0
,
1
]:
intermediate_size
=
self
.
config
.
intermediate_size
slice_size
=
intermediate_size
//
tp_size
if
mapped_shard_id
==
0
:
# gate_proj
offset
=
0
size
=
slice_size
elif
mapped_shard_id
==
1
:
# up_proj
offset
=
slice_size
size
=
slice_size
weight
=
param
.
data
.
narrow
(
0
,
offset
,
size
)
else
:
weight
=
param
.
data
else
:
weight
=
param
.
data
if
tp_size
>
1
and
(
"o_proj"
in
name
or
"down_proj"
in
name
):
gathered_weights
=
[
torch
.
zeros_like
(
weight
)
for
_
in
range
(
tp_size
)]
torch
.
distributed
.
all_gather
(
gathered_weights
,
weight
)
weight
=
torch
.
cat
(
gathered_weights
,
dim
=
1
)
return
weight
.
cpu
().
to
(
torch
.
float32
).
numpy
().
tolist
()[:
truncate_size
]
except
Exception
:
logger
.
error
(
f
"Error getting weights by name
{
name
}
in GraniteForCausalLM:
{
get_exception_traceback
()
}
"
)
return
None
EntryClass
=
[
GraniteForCausalLM
]
test/srt/models/test_generation_models.py
View file @
993956c6
...
@@ -57,6 +57,7 @@ ALL_OTHER_MODELS = [
...
@@ -57,6 +57,7 @@ ALL_OTHER_MODELS = [
ModelCase
(
"openai-community/gpt2"
),
ModelCase
(
"openai-community/gpt2"
),
ModelCase
(
"microsoft/Phi-3-small-8k-instruct"
),
ModelCase
(
"microsoft/Phi-3-small-8k-instruct"
),
ModelCase
(
"allenai/OLMo-2-1124-7B-Instruct"
,
skip_long_prompt
=
True
),
ModelCase
(
"allenai/OLMo-2-1124-7B-Instruct"
,
skip_long_prompt
=
True
),
ModelCase
(
"ibm-granite/granite-3.0-2b-instruct"
,
skip_long_prompt
=
True
),
]
]
TORCH_DTYPES
=
[
torch
.
float16
]
TORCH_DTYPES
=
[
torch
.
float16
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment