You need to sign in or sign up before continuing.
Unverified Commit 2f35b2b9 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #18 from mxCynic/main

fix: fix 9G4B model by add coefficent to weight tensor when load
parents a5deda33 45eecde6
......@@ -20,6 +20,7 @@ import safetensors
import sys
import time
import json
import math
import torch
import transformers
......@@ -89,6 +90,17 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
dt_ = DataType.INFINI_DTYPE_BF16
else:
dt_ = DataType.INFINI_DTYPE_F16
scale_input = 1.0
scale_output = 1.0
scale_o = 1.0
scale_down = 1.0
if "fm9g" == config["model_type"]:
scale_input = config["scale_emb"]
scale_output = config["hidden_size"] // config["dim_model_base"]
scale_o = config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
scale_down = config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
super().__init__(
dt_logits=dt_,
nlayer=config["num_hidden_layers"],
......@@ -101,7 +113,9 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
),
dh=config["hidden_size"] // config["num_attention_heads"],
di=config["intermediate_size"],
dctx=config["max_position_embeddings"] if max_tokens is None else max_tokens,
dctx=(
config["max_position_embeddings"] if max_tokens is None else max_tokens
),
dvoc=config["vocab_size"],
epsilon=config["rms_norm_eps"],
theta=(config["rope_theta"] if "rope_theta" in config else 100000.0),
......@@ -127,6 +141,10 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
dh = meta.dh
d = meta.d
di = meta.di
scale_input = meta.scale_input
scale_output = meta.scale_output
scale_o = meta.scale_o
scale_down = meta.scale_down
assert nh % nkvh == 0
assert nh % ndev == 0
assert nkvh % ndev == 0
......@@ -161,9 +179,13 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
)
self.transpose_linear_weights = 1 if transpose_weight else 0
self.nlayer = nlayer
self.input_embd_tensor = state_dict[input_embd_naming].to(torch_dt_logits)
self.input_embd_tensor = (
state_dict[input_embd_naming].to(torch_dt_logits) * scale_input
)
self.input_embd = self.input_embd_tensor.data_ptr()
self.output_norm_tensor = state_dict[naming.output_norm()].to(torch_dt_norm)
self.output_norm_tensor = (
state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output
)
self.output_norm = self.output_norm_tensor.data_ptr()
self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
if not transpose_weight:
......@@ -260,6 +282,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
.to(torch_dt_mat)
.contiguous()
)
* scale_o
for i in range(nlayer)
]
self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
......@@ -310,6 +333,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
.to(torch_dt_mat)
.contiguous()
)
* scale_down
for i in range(nlayer)
]
self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)]
......@@ -358,7 +382,9 @@ class JiugeBatchedTask:
class JiugeForCauslLM:
def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None):
def __init__(
self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None
):
def load_all_safetensors_from_dir(dir_path_: str):
tensors_ = {}
dir_path_ = Path(dir_path_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment