Commit 993956c6 (unverified)
Authored Dec 11, 2024 by Fred Reiss; committed via GitHub, Dec 11, 2024
Parent: f8548295

Add support for IBM Granite 3.x models (#2437)
Showing 5 changed files, with 562 additions and 1 deletion (+562, -1):
docs/references/supported_models.md           +1   -0
python/sglang/lang/chat_template.py           +32  -0
python/sglang/srt/layers/logits_processor.py  +11  -1
python/sglang/srt/models/granite.py           +517 -0
test/srt/models/test_generation_models.py     +1   -0
docs/references/supported_models.md
@@ -29,6 +29,7 @@
 - SmolLM
 - GLM-4
 - Phi-3-Small
+- IBM Granite 3

 ## Embedding Models
python/sglang/lang/chat_template.py
@@ -320,6 +320,28 @@ register_chat_template(
     )
 )

+register_chat_template(
+    ChatTemplate(
+        name="granite-3-instruct",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "<|start_of_role|>system<|end_of_role|>",
+                "<|end_of_text|>",
+            ),
+            "user": (
+                "<|start_of_role|>user<|end_of_role|>",
+                "<|end_of_text|>",
+            ),
+            "assistant": (
+                "<|start_of_role|>assistant<|end_of_role|>",
+                "<|end_of_text|>",
+            ),
+        },
+        stop_str=("<|end_of_text|>",),
+    )
+)
+
 @register_chat_template_matching_function
 def match_dbrx(model_path: str):

@@ -402,6 +424,16 @@ def match_c4ai_command_r(model_path: str):
     return get_chat_template("c4ai-command-r")

+@register_chat_template_matching_function
+def match_granite_instruct(model_path: str):
+    model_path = model_path.lower()
+    # When future versions of Granite are released, this code may
+    # need to be updated. For now, assume that the Granite 3.0
+    # template works across the board.
+    if "granite" in model_path and "instruct" in model_path:
+        return get_chat_template("granite-3-instruct")
+
 if __name__ == "__main__":
     messages = [
         {"role": "system", "content": None},  # None means default
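The template above maps each chat role to a prefix/suffix pair. As a quick illustration of the prompt string it produces, here is a minimal standalone sketch: render_granite_prompt is a hypothetical helper, not the sglang API; only the role tokens and stop string come from the diff.

ROLE_PREFIX_AND_SUFFIX = {
    "system": ("<|start_of_role|>system<|end_of_role|>", "<|end_of_text|>"),
    "user": ("<|start_of_role|>user<|end_of_role|>", "<|end_of_text|>"),
    "assistant": ("<|start_of_role|>assistant<|end_of_role|>", "<|end_of_text|>"),
}

def render_granite_prompt(messages):
    # Wrap each message in its role's prefix/suffix, then open an
    # assistant turn for the model to complete.
    parts = []
    for msg in messages:
        prefix, suffix = ROLE_PREFIX_AND_SUFFIX[msg["role"]]
        parts.append(prefix + msg["content"] + suffix)
    parts.append("<|start_of_role|>assistant<|end_of_role|>")
    return "".join(parts)

print(render_granite_prompt([{"role": "user", "content": "What is Granite?"}]))
# -> <|start_of_role|>user<|end_of_role|>What is Granite?<|end_of_text|><|start_of_role|>assistant<|end_of_role|>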
python/sglang/srt/layers/logits_processor.py
@@ -91,9 +91,12 @@ class LogitsMetadata:
 class LogitsProcessor(nn.Module):
-    def __init__(self, config, skip_all_gather: bool = False):
+    def __init__(
+        self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None
+    ):
         super().__init__()
         self.config = config
+        self.logit_scale = logit_scale
         self.do_tensor_parallel_all_gather = (
             not skip_all_gather and get_tensor_model_parallel_world_size() > 1
         )

@@ -240,6 +243,9 @@ class LogitsProcessor(nn.Module):
         all_logits = self._get_logits(states, lm_head)
         if self.do_tensor_parallel_all_gather:
             all_logits = tensor_model_parallel_all_gather(all_logits)
+
+        # The LM head's weights may be zero-padded for parallelism. Remove any
+        # extra logits that this padding may have produced.
         all_logits = all_logits[:, : self.config.vocab_size].float()

         if hasattr(self.config, "final_logit_softcapping"):

@@ -302,6 +308,10 @@ class LogitsProcessor(nn.Module):
         else:
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)

+        # Optional scaling factor, backported from vLLM 0.4
+        if self.logit_scale is not None:
+            logits.mul_(self.logit_scale)  # In-place multiply
+
         return logits
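The new logit_scale hook matters because some model families (Granite among them) are trained with scaled output logits. A toy sketch of the two post-processing steps the diff adds, with made-up sizes and scale:

import torch

vocab_size = 49152        # true vocabulary size (hypothetical)
padded_vocab = 49280      # LM head padded for tensor parallelism (hypothetical)
logit_scale = 0.125       # hypothetical Optional[float] scale

all_logits = torch.randn(4, padded_vocab)

# Remove any extra logits produced by zero-padded LM-head rows.
all_logits = all_logits[:, :vocab_size].float()

# Optional scaling factor, applied in place as in the diff above.
if logit_scale is not None:
    all_logits.mul_(logit_scale)

assert all_logits.shape == (4, vocab_size)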
python/sglang/srt/models/granite.py (new file, mode 0 → 100644, +517 lines; diff collapsed, contents not shown)
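Since the granite.py diff is collapsed, here is a hedged guess at the one piece that would connect it to the LogitsProcessor change above: HuggingFace Granite configs expose a logits_scaling divisor, so the model presumably passes its reciprocal as logit_scale. A sketch only, not the committed code:

from typing import Optional

from sglang.srt.layers.logits_processor import LogitsProcessor

class GraniteSketch:
    # Illustrative stub; the real GraniteForCausalLM in this commit is
    # not shown because the diff is collapsed above.
    def __init__(self, config):
        # Assumption: Granite configs carry `logits_scaling`, a divisor
        # applied to the final logits; passing its reciprocal lets the
        # in-place logits.mul_(logit_scale) reproduce that division.
        scale: Optional[float] = 1.0 / getattr(config, "logits_scaling", 1.0)
        self.logits_processor = LogitsProcessor(config, logit_scale=scale)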
test/srt/models/test_generation_models.py
@@ -57,6 +57,7 @@ ALL_OTHER_MODELS = [
     ModelCase("openai-community/gpt2"),
     ModelCase("microsoft/Phi-3-small-8k-instruct"),
     ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
+    ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
 ]

 TORCH_DTYPES = [torch.float16]
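To exercise the new model end to end, something like the following should work with the sgl frontend. This is a sketch assuming the sglang Python API contemporary with this commit (sgl.Runtime and the chat-role helpers); the model path comes from the test case above.

import sglang as sgl

# Serve the Granite checkpoint added to the test list and ask one question.
runtime = sgl.Runtime(model_path="ibm-granite/granite-3.0-2b-instruct")
sgl.set_default_backend(runtime)

@sgl.function
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

state = qa.run(question="What is IBM Granite?")
print(state["answer"])
runtime.shutdown()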