Unverified Commit 2f35b2b9 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #18 from mxCynic/main

fix: fix 9G4B model by adding a coefficient to weight tensors at load time
parents a5deda33 45eecde6
...@@ -20,6 +20,7 @@ import safetensors ...@@ -20,6 +20,7 @@ import safetensors
import sys import sys
import time import time
import json import json
import math
import torch import torch
import transformers import transformers
...@@ -89,6 +90,17 @@ class JiugeMetaFromLlama(JiugeMetaCStruct): ...@@ -89,6 +90,17 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
dt_ = DataType.INFINI_DTYPE_BF16 dt_ = DataType.INFINI_DTYPE_BF16
else: else:
dt_ = DataType.INFINI_DTYPE_F16 dt_ = DataType.INFINI_DTYPE_F16
scale_input = 1.0
scale_output = 1.0
scale_o = 1.0
scale_down = 1.0
if "fm9g" == config["model_type"]:
scale_input = config["scale_emb"]
scale_output = config["hidden_size"] // config["dim_model_base"]
scale_o = config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
scale_down = config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
super().__init__( super().__init__(
dt_logits=dt_, dt_logits=dt_,
nlayer=config["num_hidden_layers"], nlayer=config["num_hidden_layers"],
...@@ -101,7 +113,9 @@ class JiugeMetaFromLlama(JiugeMetaCStruct): ...@@ -101,7 +113,9 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
), ),
dh=config["hidden_size"] // config["num_attention_heads"], dh=config["hidden_size"] // config["num_attention_heads"],
di=config["intermediate_size"], di=config["intermediate_size"],
dctx=config["max_position_embeddings"] if max_tokens is None else max_tokens, dctx=(
config["max_position_embeddings"] if max_tokens is None else max_tokens
),
dvoc=config["vocab_size"], dvoc=config["vocab_size"],
epsilon=config["rms_norm_eps"], epsilon=config["rms_norm_eps"],
theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), theta=(config["rope_theta"] if "rope_theta" in config else 100000.0),
...@@ -127,6 +141,10 @@ class JiugeWeightsImpl(JiugeWeightsCStruct): ...@@ -127,6 +141,10 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
dh = meta.dh dh = meta.dh
d = meta.d d = meta.d
di = meta.di di = meta.di
scale_input = meta.scale_input
scale_output = meta.scale_output
scale_o = meta.scale_o
scale_down = meta.scale_down
assert nh % nkvh == 0 assert nh % nkvh == 0
assert nh % ndev == 0 assert nh % ndev == 0
assert nkvh % ndev == 0 assert nkvh % ndev == 0
...@@ -161,9 +179,13 @@ class JiugeWeightsImpl(JiugeWeightsCStruct): ...@@ -161,9 +179,13 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
) )
self.transpose_linear_weights = 1 if transpose_weight else 0 self.transpose_linear_weights = 1 if transpose_weight else 0
self.nlayer = nlayer self.nlayer = nlayer
self.input_embd_tensor = state_dict[input_embd_naming].to(torch_dt_logits) self.input_embd_tensor = (
state_dict[input_embd_naming].to(torch_dt_logits) * scale_input
)
self.input_embd = self.input_embd_tensor.data_ptr() self.input_embd = self.input_embd_tensor.data_ptr()
self.output_norm_tensor = state_dict[naming.output_norm()].to(torch_dt_norm) self.output_norm_tensor = (
state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output
)
self.output_norm = self.output_norm_tensor.data_ptr() self.output_norm = self.output_norm_tensor.data_ptr()
self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
if not transpose_weight: if not transpose_weight:
...@@ -260,6 +282,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct): ...@@ -260,6 +282,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
.to(torch_dt_mat) .to(torch_dt_mat)
.contiguous() .contiguous()
) )
* scale_o
for i in range(nlayer) for i in range(nlayer)
] ]
self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)] self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
...@@ -310,6 +333,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct): ...@@ -310,6 +333,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
.to(torch_dt_mat) .to(torch_dt_mat)
.contiguous() .contiguous()
) )
* scale_down
for i in range(nlayer) for i in range(nlayer)
] ]
self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)] self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)]
...@@ -358,7 +382,9 @@ class JiugeBatchedTask: ...@@ -358,7 +382,9 @@ class JiugeBatchedTask:
class JiugeForCauslLM: class JiugeForCauslLM:
def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None): def __init__(
self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None
):
def load_all_safetensors_from_dir(dir_path_: str): def load_all_safetensors_from_dir(dir_path_: str):
tensors_ = {} tensors_ = {}
dir_path_ = Path(dir_path_) dir_path_ = Path(dir_path_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment