Commit 0fca9576 authored by mxCynic's avatar mxCynic
Browse files

fix: fix 9G4B model by adding coefficients to weight tensors when loading

parent a5deda33
...@@ -20,6 +20,7 @@ import safetensors ...@@ -20,6 +20,7 @@ import safetensors
import sys import sys
import time import time
import json import json
import math
import torch import torch
import transformers import transformers
...@@ -101,10 +102,28 @@ class JiugeMetaFromLlama(JiugeMetaCStruct): ...@@ -101,10 +102,28 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
), ),
dh=config["hidden_size"] // config["num_attention_heads"], dh=config["hidden_size"] // config["num_attention_heads"],
di=config["intermediate_size"], di=config["intermediate_size"],
dctx=config["max_position_embeddings"] if max_tokens is None else max_tokens, dctx=(
config["max_position_embeddings"] if max_tokens is None else max_tokens
),
dvoc=config["vocab_size"], dvoc=config["vocab_size"],
epsilon=config["rms_norm_eps"], epsilon=config["rms_norm_eps"],
theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), theta=(config["rope_theta"] if "rope_theta" in config else 100000.0),
scale_input=(config["scale_emb"] if "scale_emb" in config else 1.0),
scale_output=(
config["hidden_size"] // config["dim_model_base"]
if "dim_model_base" in config
else 1.0
),
scale_o=(
config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
if "scale_depth" in config
else 1.0
),
scale_down=(
config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
if "scale_depth" in config
else 1.0
),
end_token=2, end_token=2,
) )
self.torch_dtype_logits = dtype self.torch_dtype_logits = dtype
...@@ -127,6 +146,10 @@ class JiugeWeightsImpl(JiugeWeightsCStruct): ...@@ -127,6 +146,10 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
dh = meta.dh dh = meta.dh
d = meta.d d = meta.d
di = meta.di di = meta.di
scale_input = meta.scale_input
scale_output = meta.scale_output
scale_o = meta.scale_o
scale_down = meta.scale_down
assert nh % nkvh == 0 assert nh % nkvh == 0
assert nh % ndev == 0 assert nh % ndev == 0
assert nkvh % ndev == 0 assert nkvh % ndev == 0
...@@ -161,9 +184,13 @@ class JiugeWeightsImpl(JiugeWeightsCStruct): ...@@ -161,9 +184,13 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
) )
self.transpose_linear_weights = 1 if transpose_weight else 0 self.transpose_linear_weights = 1 if transpose_weight else 0
self.nlayer = nlayer self.nlayer = nlayer
self.input_embd_tensor = state_dict[input_embd_naming].to(torch_dt_logits) self.input_embd_tensor = (
state_dict[input_embd_naming].to(torch_dt_logits) * scale_input
)
self.input_embd = self.input_embd_tensor.data_ptr() self.input_embd = self.input_embd_tensor.data_ptr()
self.output_norm_tensor = state_dict[naming.output_norm()].to(torch_dt_norm) self.output_norm_tensor = (
state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output
)
self.output_norm = self.output_norm_tensor.data_ptr() self.output_norm = self.output_norm_tensor.data_ptr()
self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
if not transpose_weight: if not transpose_weight:
...@@ -260,6 +287,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct): ...@@ -260,6 +287,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
.to(torch_dt_mat) .to(torch_dt_mat)
.contiguous() .contiguous()
) )
* scale_o
for i in range(nlayer) for i in range(nlayer)
] ]
self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)] self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
...@@ -310,6 +338,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct): ...@@ -310,6 +338,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
.to(torch_dt_mat) .to(torch_dt_mat)
.contiguous() .contiguous()
) )
* scale_down
for i in range(nlayer) for i in range(nlayer)
] ]
self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)] self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)]
...@@ -358,7 +387,9 @@ class JiugeBatchedTask: ...@@ -358,7 +387,9 @@ class JiugeBatchedTask:
class JiugeForCauslLM: class JiugeForCauslLM:
def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None): def __init__(
self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None
):
def load_all_safetensors_from_dir(dir_path_: str): def load_all_safetensors_from_dir(dir_path_: str):
tensors_ = {} tensors_ = {}
dir_path_ = Path(dir_path_) dir_path_ = Path(dir_path_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment