Unverified commit 993956c6 authored by Fred Reiss, committed by GitHub

Add support for IBM Granite 3.x models (#2437)

parent f8548295
@@ -29,6 +29,7 @@
- SmolLM
- GLM-4
- Phi-3-Small
- IBM Granite 3
## Embedding Models
...
@@ -320,6 +320,28 @@ register_chat_template(
    )
)
register_chat_template(
    ChatTemplate(
        name="granite-3-instruct",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|start_of_role|>system<|end_of_role|>",
                "<|end_of_text|>",
            ),
            "user": (
                "<|start_of_role|>user<|end_of_role|>",
                "<|end_of_text|>",
            ),
            "assistant": (
                "<|start_of_role|>assistant<|end_of_role|>",
                "<|end_of_text|>",
            ),
        },
        stop_str=("<|end_of_text|>",),
    )
)
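For reference (not part of the diff), here is a minimal, self-contained sketch of how the role prefixes and suffixes registered above compose into a Granite 3 prompt. It mirrors the template definition directly rather than calling sglang's own rendering code; the plain string concatenation and the trailing open assistant turn are assumptions about how the template gets applied.

```python
# Standalone sketch; mirrors the granite-3-instruct template above but does
# not use sglang's ChatTemplate machinery.
ROLE_MARKUP = {
    "system": ("<|start_of_role|>system<|end_of_role|>", "<|end_of_text|>"),
    "user": ("<|start_of_role|>user<|end_of_role|>", "<|end_of_text|>"),
    "assistant": ("<|start_of_role|>assistant<|end_of_role|>", "<|end_of_text|>"),
}

def render_granite_prompt(messages):
    """Wrap each message in its role markup, then open an assistant turn."""
    parts = []
    for msg in messages:
        prefix, suffix = ROLE_MARKUP[msg["role"]]
        parts.append(f"{prefix}{msg['content']}{suffix}")
    # Leave the final assistant turn open so the model fills it in.
    parts.append("<|start_of_role|>assistant<|end_of_role|>")
    return "".join(parts)

print(render_granite_prompt([{"role": "user", "content": "What is 2 + 2?"}]))
```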
@register_chat_template_matching_function
def match_dbrx(model_path: str):
@@ -402,6 +424,16 @@ def match_c4ai_command_r(model_path: str):
    return get_chat_template("c4ai-command-r")
@register_chat_template_matching_function
def match_granite_instruct(model_path: str):
    model_path = model_path.lower()
    # When future versions of Granite are released, this code may
    # need to be updated. For now, assume that the Granite 3.0
    # template works across the board.
    if "granite" in model_path and "instruct" in model_path:
        return get_chat_template("granite-3-instruct")
if __name__ == "__main__":
    messages = [
        {"role": "system", "content": None},  # None means default
...
@@ -91,9 +91,12 @@ class LogitsMetadata:
class LogitsProcessor(nn.Module):
    def __init__(
        self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None
    ):
        super().__init__()
        self.config = config
        self.logit_scale = logit_scale
        self.do_tensor_parallel_all_gather = (
            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
        )
@@ -240,6 +243,9 @@ class LogitsProcessor(nn.Module):
        all_logits = self._get_logits(states, lm_head)
        if self.do_tensor_parallel_all_gather:
            all_logits = tensor_model_parallel_all_gather(all_logits)
        # The LM head's weights may be zero-padded for parallelism. Remove any
        # extra logits that this padding may have produced.
        all_logits = all_logits[:, : self.config.vocab_size].float()
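Not from the diff: a toy illustration (with made-up numbers) of why that slice is needed when the LM head's vocabulary dimension is padded so the weight shards divide evenly across tensor-parallel ranks.

```python
import torch

vocab_size = 49155                  # hypothetical real vocabulary size
pad_multiple = 64                   # hypothetical padding granularity
padded_vocab = -(-vocab_size // pad_multiple) * pad_multiple  # 49216, splits evenly over 4 ranks

all_logits = torch.randn(2, padded_vocab)       # gathered logits, pad columns included
trimmed = all_logits[:, :vocab_size].float()    # same slice as in the code above
assert trimmed.shape == (2, vocab_size)         # pad columns correspond to no real tokens
```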
        if hasattr(self.config, "final_logit_softcapping"):
@@ -302,6 +308,10 @@ class LogitsProcessor(nn.Module):
        else:
            # GGUF models
            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
        # Optional scaling factor, backported from vLLM 0.4
        if self.logit_scale is not None:
            logits.mul_(self.logit_scale)  # In-place multiply
        return logits
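The model file that consumes the new parameter sits in the collapsed diff above, so the wiring shown here is an assumption: Granite 3 checkpoints publish a `logits_scaling` divisor in their config, and passing its reciprocal as `logit_scale` reproduces that division via the in-place multiply.

```python
import torch

logits_scaling = 16.0               # assumed example value from a Granite 3 config
logit_scale = 1.0 / logits_scaling  # divisor expressed as a multiplicative factor

logits = torch.randn(2, 8)
expected = logits * logit_scale
logits.mul_(logit_scale)            # same in-place multiply as in _get_logits above
assert torch.allclose(logits, expected)
```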
...
This diff is collapsed.
@@ -57,6 +57,7 @@ ALL_OTHER_MODELS = [
    ModelCase("openai-community/gpt2"),
    ModelCase("microsoft/Phi-3-small-8k-instruct"),
    ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
    ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
]
TORCH_DTYPES = [torch.float16]
...
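Beyond the unit-test entry above, a quick end-to-end check could look like the sketch below. The offline `sgl.Engine` API and its `generate`/`shutdown` calls are assumed from current sglang usage and are not part of this commit; the prompt follows the granite-3-instruct markup registered earlier.

```python
import sglang as sgl

# Assumes a GPU with enough memory for the 2B instruct checkpoint.
llm = sgl.Engine(model_path="ibm-granite/granite-3.0-2b-instruct")
prompt = (
    "<|start_of_role|>user<|end_of_role|>Name one use of a hash map.<|end_of_text|>"
    "<|start_of_role|>assistant<|end_of_role|>"
)
out = llm.generate(prompt, {"temperature": 0.0, "max_new_tokens": 64})
print(out["text"])
llm.shutdown()
```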