Unverified Commit 4936be8a, authored by Lianmin Zheng, committed by GitHub

Revert "Revert "[FEAT] Support GGUF format"" (#2287)

parent 1bfa511b
@@ -15,3 +15,4 @@ sphinx-copybutton
sphinx-tabs
sphinxcontrib-mermaid
urllib3<2.0.0
+gguf>=0.10.0
@@ -16,6 +16,7 @@
import contextlib
import os
import warnings
+from pathlib import Path
from typing import Dict, Optional, Type, Union

from huggingface_hub import snapshot_download
@@ -27,6 +28,7 @@ from transformers import (
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

try:
    from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
@@ -60,15 +62,29 @@ def get_config(
    trust_remote_code: bool,
    revision: Optional[str] = None,
    model_override_args: Optional[dict] = None,
+    **kwargs,
):
+    is_gguf = check_gguf_file(model)
+    if is_gguf:
+        kwargs["gguf_file"] = model
+        model = Path(model).parent
+
    config = AutoConfig.from_pretrained(
-        model, trust_remote_code=trust_remote_code, revision=revision
+        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model, revision=revision)
    if model_override_args:
        config.update(model_override_args)
+
+    # Special architecture mapping check for GGUF models
+    if is_gguf:
+        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
+        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
+        config.update({"architectures": [model_type]})
+
    return config
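For illustration (not part of the commit), a minimal sketch of what the GGUF branch above does with a local checkpoint path before it reaches `AutoConfig.from_pretrained`; the path below is made up:

```python
from pathlib import Path

# Hypothetical local GGUF checkpoint.
model = "/models/llama-3-8b/llama-3-8b.Q4_K_M.gguf"

kwargs = {}
if model.endswith(".gguf"):      # stand-in for check_gguf_file(model)
    kwargs["gguf_file"] = model  # the config is read out of the GGUF file itself
    model = Path(model).parent   # AutoConfig then gets the containing directory

print(model)   # /models/llama-3-8b
print(kwargs)  # {'gguf_file': '/models/llama-3-8b/llama-3-8b.Q4_K_M.gguf'}
```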
@@ -123,6 +139,11 @@ def get_tokenizer(
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

+    is_gguf = check_gguf_file(tokenizer_name)
+    if is_gguf:
+        kwargs["gguf_file"] = tokenizer_name
+        tokenizer_name = Path(tokenizer_name).parent
+
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
@@ -195,3 +216,16 @@ def attach_additional_stop_token_ids(tokenizer):
        )
    else:
        tokenizer.additional_stop_token_ids = None
+
+
+def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
+    """Check if the file is a GGUF model."""
+    model = Path(model)
+    if not model.is_file():
+        return False
+    elif model.suffix == ".gguf":
+        return True
+
+    with open(model, "rb") as f:
+        header = f.read(4)
+    return header == b"GGUF"
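A quick standalone check (my own example, not in the commit) of the two detection paths in `check_gguf_file`: the `.gguf` suffix and the 4-byte `GGUF` magic. The temporary file is hypothetical:

```python
import tempfile
from pathlib import Path

def looks_like_gguf(path) -> bool:
    # Same logic as check_gguf_file above, inlined so this snippet is self-contained.
    p = Path(path)
    if not p.is_file():
        return False
    if p.suffix == ".gguf":
        return True
    with open(p, "rb") as f:
        return f.read(4) == b"GGUF"

# A file without the .gguf extension but with the GGUF magic bytes in its header.
with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f:
    f.write(b"GGUF" + b"\x00" * 16)

print(looks_like_gguf(f.name))           # True  (detected via the magic bytes)
print(looks_like_gguf("missing.gguf"))   # False (not a file on disk)
```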
@@ -23,6 +23,7 @@ from vllm.distributed import (
    tensor_model_parallel_all_gather,
)

+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -163,7 +164,7 @@ class LogitsProcessor(nn.Module):
        self,
        input_ids,
        hidden_states,
-        weight,
+        lm_head: VocabParallelEmbedding,
        logits_metadata: Union[LogitsMetadata, ForwardBatch],
    ):
        if isinstance(logits_metadata, ForwardBatch):
@@ -178,7 +179,7 @@ class LogitsProcessor(nn.Module):
            last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
            last_hidden = hidden_states[last_index]

-        last_logits = torch.matmul(last_hidden, weight.T)
+        last_logits = self._get_logits(last_hidden, lm_head)
        if self.do_tensor_parallel_all_gather:
            last_logits = tensor_model_parallel_all_gather(last_logits)
        last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -229,7 +230,7 @@ class LogitsProcessor(nn.Module):
            # Compute the logits and logprobs for all required tokens
            states = torch.cat(states, dim=0)
-            all_logits = torch.matmul(states, weight.T)
+            all_logits = self._get_logits(states, lm_head)
            if self.do_tensor_parallel_all_gather:
                all_logits = tensor_model_parallel_all_gather(all_logits)
            all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -276,6 +277,19 @@ class LogitsProcessor(nn.Module):
                output_top_logprobs=output_top_logprobs,
            )

+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if hasattr(lm_head, "weight"):
+            logits = torch.matmul(hidden_states, lm_head.weight.T)
+        else:
+            # GGUF models
+            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+        return logits
+

def test():
    all_logprobs = torch.tensor(
......
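The `_get_logits` helper above is the reason `forward` now receives the `lm_head` module instead of a bare weight tensor: a dense head still reduces to a matmul, while a quantized head (the GGUF case) has no `.weight` and must go through its linear method. Below is a self-contained sketch of that dispatch, not part of the commit; `DenseHead`, `QuantHead`, and `FakeQuantMethod` are made-up stand-ins:

```python
import torch

class DenseHead(torch.nn.Module):
    def __init__(self, vocab, hidden):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(vocab, hidden))

class FakeQuantMethod:
    # Stand-in for a quantized linear method: "dequantize" and matmul on the fly.
    def apply(self, layer, x, bias=None):
        w = layer.packed_weight.float()  # pretend this is dequantization
        return x @ w.T if bias is None else x @ w.T + bias

class QuantHead(torch.nn.Module):
    # Note: no .weight attribute, so logits must go through linear_method.apply.
    def __init__(self, vocab, hidden):
        super().__init__()
        self.register_buffer("packed_weight", torch.randn(vocab, hidden).half())
        self.linear_method = FakeQuantMethod()

def get_logits(hidden_states, lm_head):
    if hasattr(lm_head, "weight"):
        return hidden_states @ lm_head.weight.T
    return lm_head.linear_method.apply(lm_head, hidden_states)

x = torch.randn(2, 16)
print(get_logits(x, DenseHead(32, 16)).shape)   # torch.Size([2, 32])
print(get_logits(x, QuantHead(32, 16)).shape)   # torch.Size([2, 32])
```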
@@ -222,6 +222,7 @@ class VocabParallelEmbedding(torch.nn.Module):
        enable_tp: bool = True,
    ):
        super().__init__()
+        self.quant_config = quant_config
        self.enable_tp = enable_tp
        if self.enable_tp:
......
@@ -59,6 +59,7 @@ from sglang.srt.utils import (
    enable_show_time_cost,
    get_available_gpu_memory,
    is_hip,
+    monkey_patch_vllm_gguf_config,
    monkey_patch_vllm_model_config,
    monkey_patch_vllm_p2p_access_check,
    set_cpu_offload_max_bytes,
@@ -297,6 +298,8 @@ class ModelRunner:
            download_dir=self.server_args.download_dir,
        )
        monkey_patch_vllm_model_config()
+        if self.server_args.load_format == "gguf":
+            monkey_patch_vllm_gguf_config()
        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
......
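For orientation (my addition, and only a sketch): the GGUF-specific vLLM patch is applied only when the server was started with the GGUF load format, and it runs before `VllmModelConfig` is constructed so the patched config handling is in effect when the model config is built. The `ServerArgs` dataclass below is a trimmed-down, hypothetical stand-in for the real server arguments:

```python
from dataclasses import dataclass

@dataclass
class ServerArgs:                 # hypothetical, minimal stand-in
    model_path: str
    load_format: str = "auto"

def patches_to_apply(args: ServerArgs) -> list:
    patches = ["monkey_patch_vllm_model_config"]       # always applied, as above
    if args.load_format == "gguf":                      # the new branch in this commit
        patches.append("monkey_patch_vllm_gguf_config")
    return patches

print(patches_to_apply(ServerArgs("/models/llama.Q4_K_M.gguf", load_format="gguf")))
# ['monkey_patch_vllm_model_config', 'monkey_patch_vllm_gguf_config']
```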
@@ -338,11 +338,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
        self.quant_config = quant_config
        self.model = BaiChuanModel(config, position_embedding, quant_config)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
-        )
        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
        self.logits_processor = LogitsProcessor(config)

    def forward(
@@ -353,7 +354,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
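The `tie_word_embeddings` change above (repeated for Llama further down) stops copying `lm_head.weight` and instead reuses the embedding module itself, which also works when the checkpoint, for example a GGUF file, ships only a single and possibly quantized embedding tensor. A minimal sketch of the tying idea, my own illustration:

```python
import torch

vocab, hidden = 8, 4
embed_tokens = torch.nn.Embedding(vocab, hidden)

# Tied head: point lm_head at the embedding module instead of copying its weight.
lm_head = embed_tokens

h = torch.randn(2, hidden)
logits = h @ lm_head.weight.T                   # output projection reuses the input matrix
print(logits.shape)                             # torch.Size([2, 8])
print(lm_head.weight is embed_tokens.weight)    # True: one tensor, no duplicate copy
```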
@@ -378,7 +378,7 @@ class ChatGLMForCausalLM(nn.Module):
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, forward_batch)
        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
@@ -339,7 +339,7 @@ class CohereForCausalLM(nn.Module):
            forward_batch,
        )
        return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
@@ -390,7 +390,7 @@ class DbrxForCausalLM(nn.Module):
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, forward_batch)
        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
@@ -394,7 +394,7 @@ class DeepseekForCausalLM(nn.Module):
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
@@ -763,7 +763,7 @@ class DeepseekV2ForCausalLM(nn.Module):
        hidden_states = self.model(input_ids, positions, forward_batch)
        if not forward_batch.forward_mode.is_idle():
            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
            )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
@@ -314,7 +314,7 @@ class ExaoneForCausalLM(nn.Module):
            input_ids, positions, forward_batch, input_embeds
        )
        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
@@ -298,7 +298,7 @@ class GemmaForCausalLM(nn.Module):
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
@@ -363,7 +363,7 @@ class Gemma2ForCausalLM(nn.Module):
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    def get_attention_sliding_window_size(self):
......
@@ -247,7 +247,7 @@ class GPT2LMHeadModel(nn.Module):
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, forward_batch)
        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
@@ -271,7 +271,7 @@ class GPTBigCodeForCausalLM(nn.Module):
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, forward_batch)
        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
@@ -304,7 +304,7 @@ class Grok1ForCausalLM(nn.Module):
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
@@ -270,7 +270,7 @@ class InternLM2ForCausalLM(nn.Module):
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
-            input_ids, hidden_states, self.output.weight, forward_batch
+            input_ids, hidden_states, self.output, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers
+from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)
@@ -258,6 +259,7 @@ class LlamaModel(nn.Module):
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
+            quant_config=quant_config,
        )
        self.layers = make_layers(
            config.num_hidden_layers,
@@ -305,7 +307,12 @@ class LlamaForCausalLM(nn.Module):
        self.quant_config = quant_config
        self.torchao_config = global_server_args_dict["torchao_config"]
        self.model = LlamaModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.stacked_params_mapping = [
@@ -329,7 +336,7 @@ class LlamaForCausalLM(nn.Module):
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        if not get_embedding:
            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return self.pooler(hidden_states, forward_batch)
@@ -373,7 +380,6 @@ class LlamaForCausalLM(nn.Module):
        return len(params_dict)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        embed_tokens_weight = None
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
@@ -385,12 +391,6 @@ class LlamaForCausalLM(nn.Module):
        params_dict = dict(self.named_parameters())

-        load_tie_word_embeddings = (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-            and "lm_head.weight" in params_dict
-        )
-
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
@@ -423,16 +423,6 @@ class LlamaForCausalLM(nn.Module):
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

-            if load_tie_word_embeddings and name == "model.embed_tokens.weight":
-                embed_tokens_weight = loaded_weight
-
-        if load_tie_word_embeddings:
-            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
-            param = self.lm_head.weight
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            if embed_tokens_weight is not None:
-                weight_loader(param, embed_tokens_weight)
-
        apply_torchao_config_(self, params_dict, set(["proj.weight"]))

    def get_weights_by_name(
@@ -444,6 +434,17 @@ class LlamaForCausalLM(nn.Module):
        For optimized performance, please use torch.save and torch.load.
        """
        try:
+            if name == "lm_head.weight" and self.config.tie_word_embeddings:
+                logger.info(
+                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
+                )
+                return (
+                    self.model.embed_tokens.weight.cpu()
+                    .to(torch.float32)
+                    .numpy()
+                    .tolist()[:truncate_size]
+                )
+
            mapped_name = name
            mapped_shard_id = None
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
@@ -452,54 +453,48 @@ class LlamaForCausalLM(nn.Module):
                    mapped_shard_id = shard_id
                    break
            params_dict = dict(self.named_parameters())
-            if mapped_name in params_dict:
-                param = params_dict[mapped_name]
-                if mapped_shard_id is not None:
-                    if mapped_shard_id in ["q", "k", "v"]:
-                        num_heads = self.config.num_attention_heads // tp_size
-                        num_kv_heads = self.config.num_key_value_heads // tp_size
-                        head_dim = (
-                            self.config.hidden_size // self.config.num_attention_heads
-                        )
-                        if mapped_shard_id == "q":
-                            offset = 0
-                            size = num_heads * head_dim
-                        elif mapped_shard_id == "k":
-                            offset = num_heads * head_dim
-                            size = num_kv_heads * head_dim
-                        elif mapped_shard_id == "v":
-                            offset = (num_heads + num_kv_heads) * head_dim
-                            size = num_kv_heads * head_dim
-                        weight = param.data.narrow(0, offset, size)
-                    elif mapped_shard_id in [0, 1]:
-                        intermediate_size = self.config.intermediate_size
-                        hidden_size = self.config.hidden_size
-                        slice_size = intermediate_size // tp_size
-                        if mapped_shard_id == 0:  # gate_proj
-                            offset = 0
-                            size = slice_size
-                        elif mapped_shard_id == 1:  # up_proj
-                            offset = slice_size
-                            size = slice_size
-                        weight = param.data.narrow(0, offset, size)
-                    else:
-                        weight = param.data
-                else:
-                    weight = param.data
-                if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
-                    gathered_weights = [
-                        torch.zeros_like(weight) for _ in range(tp_size)
-                    ]
-                    torch.distributed.all_gather(gathered_weights, weight)
-                    weight = torch.cat(gathered_weights, dim=1)
-                return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
-            else:
-                return None
-
-        except Exception as e:
+            param = params_dict[mapped_name]
+            if mapped_shard_id is not None:
+                if mapped_shard_id in ["q", "k", "v"]:
+                    num_heads = self.config.num_attention_heads // tp_size
+                    num_kv_heads = self.config.num_key_value_heads // tp_size
+                    head_dim = (
+                        self.config.hidden_size // self.config.num_attention_heads
+                    )
+                    if mapped_shard_id == "q":
+                        offset = 0
+                        size = num_heads * head_dim
+                    elif mapped_shard_id == "k":
+                        offset = num_heads * head_dim
+                        size = num_kv_heads * head_dim
+                    elif mapped_shard_id == "v":
+                        offset = (num_heads + num_kv_heads) * head_dim
+                        size = num_kv_heads * head_dim
+                    weight = param.data.narrow(0, offset, size)
+                elif mapped_shard_id in [0, 1]:
+                    intermediate_size = self.config.intermediate_size
+                    slice_size = intermediate_size // tp_size
+                    if mapped_shard_id == 0:  # gate_proj
+                        offset = 0
+                        size = slice_size
+                    elif mapped_shard_id == 1:  # up_proj
+                        offset = slice_size
+                        size = slice_size
+
+                    weight = param.data.narrow(0, offset, size)
+                else:
+                    weight = param.data
+            else:
+                weight = param.data
+            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
+                torch.distributed.all_gather(gathered_weights, weight)
+                weight = torch.cat(gathered_weights, dim=1)
+            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+        except Exception:
            logger.error(
-                f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
+                f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
            )
            return None
......
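A worked example (mine, with made-up Llama-like sizes and tp_size=1) of the `narrow()` offsets computed in `get_weights_by_name` above when slicing q/k/v back out of the fused `qkv_proj` parameter:

```python
# Hypothetical config: 32 attention heads, 8 KV heads, hidden size 4096, tp_size 1.
num_attention_heads, num_key_value_heads, hidden_size, tp_size = 32, 8, 4096, 1

head_dim = hidden_size // num_attention_heads        # 128
num_heads = num_attention_heads // tp_size           # 32
num_kv_heads = num_key_value_heads // tp_size        # 8

# (offset, size) of each shard along dim 0 of qkv_proj.weight, as in the code above.
slices = {
    "q": (0, num_heads * head_dim),                                         # (0, 4096)
    "k": (num_heads * head_dim, num_kv_heads * head_dim),                   # (4096, 1024)
    "v": ((num_heads + num_kv_heads) * head_dim, num_kv_heads * head_dim),  # (5120, 1024)
}
for shard, (offset, size) in slices.items():
    print(f"{shard}: rows [{offset}, {offset + size})")
```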
@@ -308,12 +308,10 @@ class MiniCPMForCausalLM(nn.Module):
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
        else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
......