Unverified commit d14e98d9, authored by Isotr0py and committed by GitHub
Browse files

[Model] Support GGUF models newly added in `transformers` 4.46.0 (#9685)


Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 9597a095
...@@ -3,27 +3,20 @@ from huggingface_hub import hf_hub_download ...@@ -3,27 +3,20 @@ from huggingface_hub import hf_hub_download
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
def run_gguf_inference(model_path): def run_gguf_inference(model_path, tokenizer):
PROMPT_TEMPLATE = "<|system|>\n{system_message}</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n" # noqa: E501
system_message = "You are a friendly chatbot who always responds in the style of a pirate." # noqa: E501
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
"How many helicopters can a human eat in one sitting?", "How many helicopters can a human eat in one sitting?",
"What's the future of AI?", "What's the future of AI?",
] ]
prompts = [ prompts = [[{"role": "user", "content": prompt}] for prompt in prompts]
PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt)
for prompt in prompts
]
# Create a sampling params object. # Create a sampling params object.
sampling_params = SamplingParams(temperature=0, max_tokens=128) sampling_params = SamplingParams(temperature=0, max_tokens=128)
# Create an LLM. # Create an LLM.
llm = LLM(model=model_path, llm = LLM(model=model_path, tokenizer=tokenizer)
tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
gpu_memory_utilization=0.95)
outputs = llm.generate(prompts, sampling_params) outputs = llm.chat(prompts, sampling_params)
# Print the outputs. # Print the outputs.
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
...@@ -32,7 +25,8 @@ def run_gguf_inference(model_path): ...@@ -32,7 +25,8 @@ def run_gguf_inference(model_path):
if __name__ == "__main__": if __name__ == "__main__":
repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF"
filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf" filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf"
tokenizer = "microsoft/Phi-3-medium-4k-instruct"
model = hf_hub_download(repo_id, filename=filename) model = hf_hub_download(repo_id, filename=filename)
run_gguf_inference(model) run_gguf_inference(model, tokenizer)
...@@ -4,6 +4,7 @@ Note: To pass the test, quantization higher than Q4 should be used ...@@ -4,6 +4,7 @@ Note: To pass the test, quantization higher than Q4 should be used
""" """
import os import os
from typing import List, NamedTuple, Type
import pytest import pytest
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
...@@ -11,6 +12,7 @@ from transformers import AutoTokenizer ...@@ -11,6 +12,7 @@ from transformers import AutoTokenizer
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from ....conftest import VllmRunner
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
os.environ["TOKENIZERS_PARALLELISM"] = "true" os.environ["TOKENIZERS_PARALLELISM"] = "true"
...@@ -18,31 +20,74 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true" ...@@ -18,31 +20,74 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024 MAX_MODEL_LEN = 1024
class GGUFTestConfig(NamedTuple):
    """One GGUF test case: an original checkpoint plus a GGUF file for it.

    ``test_models`` runs ``original_model`` unquantized, then runs the file
    returned by ``gguf_model`` with ``original_model`` as its tokenizer, so
    both entries must describe the same underlying model.
    """

    # Model ID passed to ``AutoTokenizer.from_pretrained`` and used as the
    # unquantized reference model (and as the GGUF run's tokenizer name).
    original_model: str
    # Hub repository that hosts the GGUF file.
    gguf_repo: str
    # Name of the GGUF file inside ``gguf_repo``.
    gguf_filename: str

    @property
    def gguf_model(self) -> str:
        """Return what ``hf_hub_download`` yields for the configured file.

        The download is triggered on attribute access rather than at import
        time, so merely building the module-level configs does no I/O.
        """
        return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)
# Per-architecture test cases. Each entry names the original checkpoint
# (`original_model`) and the GGUF repository/file (`gguf_repo`,
# `gguf_filename`) that is run against it.
LLAMA_CONFIG = GGUFTestConfig(
    original_model="meta-llama/Llama-3.2-1B-Instruct",
    gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
    gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
)
QWEN2_CONFIG = GGUFTestConfig(
    original_model="Qwen/Qwen2.5-1.5B-Instruct",
    gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
    gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf",
)
PHI3_CONFIG = GGUFTestConfig(
    original_model="microsoft/Phi-3.5-mini-instruct",
    gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
    gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf",
)
GPT2_CONFIG = GGUFTestConfig(
    original_model="openai-community/gpt2-large",
    gguf_repo="QuantFactory/gpt2-large-GGUF",
    gguf_filename="gpt2-large.Q4_K_M.gguf",
)
STABLELM_CONFIG = GGUFTestConfig(
    original_model="stabilityai/stablelm-3b-4e1t",
    gguf_repo="afrideva/stablelm-3b-4e1t-GGUF",
    gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf",
)
# Defined for completeness but not part of MODELS below.
STARCODER_CONFIG = GGUFTestConfig(
    original_model="bigcode/starcoder2-3b",
    gguf_repo="QuantFactory/starcoder2-3b-GGUF",
    gguf_filename="starcoder2-3b.Q6_K.gguf",
)
# Configs fed to `@pytest.mark.parametrize("model", MODELS)` on test_models.
MODELS = [
    LLAMA_CONFIG,
    QWEN2_CONFIG,
    PHI3_CONFIG,
    GPT2_CONFIG,
    STABLELM_CONFIG,
    # STARCODER_CONFIG, # broken
]
@pytest.mark.skipif(not is_quant_method_supported("gguf"), @pytest.mark.skipif(not is_quant_method_supported("gguf"),
reason="gguf is not supported on this GPU type.") reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize(("original_model", "gguf_id", "gguf_path"), [ @pytest.mark.parametrize("model", MODELS)
("meta-llama/Llama-3.2-1B-Instruct",
"bartowski/Llama-3.2-1B-Instruct-GGUF",
"Llama-3.2-1B-Instruct-Q4_K_M.gguf"),
("meta-llama/Llama-3.2-1B-Instruct",
"bartowski/Llama-3.2-1B-Instruct-GGUF",
"Llama-3.2-1B-Instruct-IQ4_XS.gguf"),
("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GGUF",
"qwen2-1_5b-instruct-q4_k_m.gguf"),
("Qwen/Qwen2-1.5B-Instruct", "legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
"Qwen2-1.5B-Instruct.IQ4_XS.gguf"),
])
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("tp_size", [1, 2])
def test_models( def test_models(
num_gpus_available, num_gpus_available: int,
vllm_runner, vllm_runner: Type[VllmRunner],
example_prompts, example_prompts: List[str],
original_model, model: GGUFTestConfig,
gguf_id,
gguf_path,
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
...@@ -51,28 +96,26 @@ def test_models( ...@@ -51,28 +96,26 @@ def test_models(
if num_gpus_available < tp_size: if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
gguf_model = hf_hub_download(gguf_id, filename=gguf_path) tokenizer = AutoTokenizer.from_pretrained(model.original_model)
if tokenizer.chat_template is not None:
tokenizer = AutoTokenizer.from_pretrained(original_model) messages = [[{
messages = [[{ 'role': 'user',
'role': 'user', 'content': prompt
'content': prompt }] for prompt in example_prompts]
}] for prompt in example_prompts] example_prompts = tokenizer.apply_chat_template(
example_prompts = tokenizer.apply_chat_template(messages, messages, tokenize=False, add_generation_prompt=True)
tokenize=False,
add_generation_prompt=True)
# Run unquantized model. # Run unquantized model.
with vllm_runner(model_name=original_model, with vllm_runner(model_name=model.original_model,
dtype=dtype, dtype=dtype,
max_model_len=MAX_MODEL_LEN, max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=tp_size) as original_model: tensor_parallel_size=tp_size) as original_model:
original_outputs = original_model.generate_greedy_logprobs( original_outputs = original_model.generate_greedy_logprobs(
example_prompts[:-1], max_tokens, num_logprobs) example_prompts[:-1], max_tokens, num_logprobs)
# Run gguf model. # Run gguf model.
with vllm_runner(model_name=gguf_model, with vllm_runner(model_name=model.gguf_model,
tokenizer_name=model.original_model,
dtype=dtype, dtype=dtype,
max_model_len=MAX_MODEL_LEN, max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=tp_size) as gguf_model: tensor_parallel_size=tp_size) as gguf_model:
......
...@@ -447,8 +447,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -447,8 +447,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type: if is_gguf_weight_type:
param.data[loaded_shard_id].copy_(loaded_weight) if loaded_shard_id is not None:
param.shard_weight_type[loaded_shard_id] = loaded_weight.item() param.data[loaded_shard_id].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
else:
param.shard_weight_type = {
i: loaded_weight.item()
for i, _ in enumerate(self.output_sizes)
}
return return
if is_gguf_weight: if is_gguf_weight:
...@@ -459,15 +465,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -459,15 +465,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = loaded_weight.size(output_dim) // tp_size shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, if loaded_shard_id is not None:
shard_size) loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
param.shard_id.append(loaded_shard_id) param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container) param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight) param.data_container.append(loaded_weight)
if len(param.data_container) == 2: if len(param.data_container) == 2:
self.qweight = param.materialize_nested() self.qweight = param.materialize_nested()
return return
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
...@@ -811,10 +817,16 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -811,10 +817,16 @@ class QKVParallelLinear(ColumnParallelLinear):
# initialize GGUF param after we know the quantize type # initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type and loaded_shard_id is not None: if is_gguf_weight_type:
idx_map = {"q": 0, "k": 1, "v": 2} idx_map = {"q": 0, "k": 1, "v": 2}
param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) if loaded_shard_id is not None:
param.shard_weight_type[loaded_shard_id] = loaded_weight.item() param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
else:
param.shard_weight_type = {
k: loaded_weight.item()
for k in idx_map
}
return return
if is_gguf_weight: if is_gguf_weight:
...@@ -825,15 +837,15 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -825,15 +837,15 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size = loaded_weight.size(output_dim) // tp_size shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, if loaded_shard_id is not None:
shard_size) loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
param.shard_id.append(loaded_shard_id) param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container) param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight) param.data_container.append(loaded_weight)
if len(param.data_container) == 3: if len(param.data_container) == 3:
self.qweight = param.materialize_nested() self.qweight = param.materialize_nested()
return return
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
......
...@@ -198,7 +198,10 @@ class GPT2Model(nn.Module): ...@@ -198,7 +198,10 @@ class GPT2Model(nn.Module):
assert not config.scale_attn_by_inverse_layer_idx assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wte = VocabParallelEmbedding(config.vocab_size,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.wte")
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer, self.h = make_layers( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
...@@ -259,7 +262,9 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -259,7 +262,9 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
self.lm_head = self.transformer.wte self.lm_head = self.transformer.wte
else: else:
self.lm_head = ParallelLMHead(self.config.vocab_size, self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size) self.config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.lm_head")
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
...@@ -304,7 +309,7 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -304,7 +309,7 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "lm_head.weight" in name: if name.startswith("lm_head"):
# GPT-2 ties the weights of the embedding layer and the final # GPT-2 ties the weights of the embedding layer and the final
# linear layer. # linear layer.
continue continue
......
...@@ -156,7 +156,8 @@ class LlamaAttention(nn.Module): ...@@ -156,7 +156,8 @@ class LlamaAttention(nn.Module):
) )
is_neox_style = True is_neox_style = True
if quant_config is not None and quant_config.get_name() == "gguf": is_gguf = quant_config and quant_config.get_name() == "gguf"
if is_gguf and config.model_type == "llama":
is_neox_style = False is_neox_style = False
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
......
...@@ -22,7 +22,7 @@ from typing import Iterable, List, Optional, Set, Tuple, Union ...@@ -22,7 +22,7 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import StableLmConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
...@@ -50,8 +50,9 @@ from .utils import (is_pp_missing_parameter, ...@@ -50,8 +50,9 @@ from .utils import (is_pp_missing_parameter,
class StablelmMLP(nn.Module): class StablelmMLP(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: StableLmConfig,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -59,10 +60,13 @@ class StablelmMLP(nn.Module): ...@@ -59,10 +60,13 @@ class StablelmMLP(nn.Module):
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size, [config.intermediate_size] * 2, config.hidden_size, [config.intermediate_size] * 2,
bias=False, bias=False,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(config.intermediate_size, self.down_proj = RowParallelLinear(config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=False) bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -75,7 +79,7 @@ class StablelmMLP(nn.Module): ...@@ -75,7 +79,7 @@ class StablelmMLP(nn.Module):
class StablelmAttention(nn.Module): class StablelmAttention(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: StableLmConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None: prefix: str = "") -> None:
...@@ -116,11 +120,13 @@ class StablelmAttention(nn.Module): ...@@ -116,11 +120,13 @@ class StablelmAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_key_value_heads, self.total_num_key_value_heads,
self.qkv_bias, self.qkv_bias,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.o_proj")
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.rotary_ndims, rotary_dim=self.rotary_ndims,
...@@ -154,7 +160,7 @@ class StablelmDecoderLayer(nn.Module): ...@@ -154,7 +160,7 @@ class StablelmDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: StableLmConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
...@@ -164,7 +170,7 @@ class StablelmDecoderLayer(nn.Module): ...@@ -164,7 +170,7 @@ class StablelmDecoderLayer(nn.Module):
cache_config, cache_config,
quant_config, quant_config,
prefix=f"{prefix}.self_attn") prefix=f"{prefix}.self_attn")
self.mlp = StablelmMLP(config, quant_config) self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp")
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05)) getattr(config, "layer_norm_eps", 1e-05))
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
...@@ -210,6 +216,8 @@ class StableLMEpochModel(nn.Module): ...@@ -210,6 +216,8 @@ class StableLMEpochModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
...@@ -270,7 +278,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -270,7 +278,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.lm_head")
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
......
...@@ -88,12 +88,14 @@ class Starcoder2Attention(nn.Module): ...@@ -88,12 +88,14 @@ class Starcoder2Attention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=self.use_bias, bias=self.use_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=self.use_bias, bias=self.use_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj",
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -129,19 +131,22 @@ class Starcoder2MLP(nn.Module): ...@@ -129,19 +131,22 @@ class Starcoder2MLP(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.c_fc = ColumnParallelLinear( self.c_fc = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
bias=config.use_bias, bias=config.use_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_fc",
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=config.use_bias, bias=config.use_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_proj",
) )
self.act = get_act_fn(config.hidden_act) self.act = get_act_fn(config.hidden_act)
...@@ -165,7 +170,9 @@ class Starcoder2DecoderLayer(nn.Module): ...@@ -165,7 +170,9 @@ class Starcoder2DecoderLayer(nn.Module):
cache_config, cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn") prefix=f"{prefix}.self_attn")
self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.mlp = Starcoder2MLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon) eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
...@@ -213,8 +220,11 @@ class Starcoder2Model(nn.Module): ...@@ -213,8 +220,11 @@ class Starcoder2Model(nn.Module):
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
# TODO: consider padding_idx (currently removed) # TODO: consider padding_idx (currently removed)
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(
config.hidden_size) config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: Starcoder2DecoderLayer( lambda prefix: Starcoder2DecoderLayer(
...@@ -279,6 +289,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -279,6 +289,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.lm_head",
) )
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment