Commit 45eecde6 authored by mxCynic

fix: Use model_type to determine whether to load scale-related parameters

parent 0fca9576
@@ -90,6 +90,17 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
             dt_ = DataType.INFINI_DTYPE_BF16
         else:
             dt_ = DataType.INFINI_DTYPE_F16
+        scale_input = 1.0
+        scale_output = 1.0
+        scale_o = 1.0
+        scale_down = 1.0
+        if "fm9g" == config["model_type"]:
+            scale_input = config["scale_emb"]
+            scale_output = config["hidden_size"] // config["dim_model_base"]
+            scale_o = config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
+            scale_down = config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
         super().__init__(
             dt_logits=dt_,
             nlayer=config["num_hidden_layers"],
@@ -108,22 +119,6 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
             dvoc=config["vocab_size"],
             epsilon=config["rms_norm_eps"],
             theta=(config["rope_theta"] if "rope_theta" in config else 100000.0),
-            scale_input=(config["scale_emb"] if "scale_emb" in config else 1.0),
-            scale_output=(
-                config["hidden_size"] // config["dim_model_base"]
-                if "dim_model_base" in config
-                else 1.0
-            ),
-            scale_o=(
-                config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
-                if "scale_depth" in config
-                else 1.0
-            ),
-            scale_down=(
-                config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
-                if "scale_depth" in config
-                else 1.0
-            ),
             end_token=2,
         )
         self.torch_dtype_logits = dtype
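
In effect, the commit replaces four independent key-presence checks with a single gate on config["model_type"]: a non-FM9G config that merely happens to contain a "scale_emb" or "scale_depth" key no longer triggers FM9G-style scaling. A minimal sketch of the resulting behavior follows; resolve_scales is an illustrative helper, not part of the repository, and the sample config values are hypothetical.

import math

def resolve_scales(config: dict) -> dict:
    # Mirror the commit's logic: default all scales to 1.0, then apply the
    # FM9G-specific values only when model_type is exactly "fm9g".
    scales = {"scale_input": 1.0, "scale_output": 1.0, "scale_o": 1.0, "scale_down": 1.0}
    if config.get("model_type") == "fm9g":
        residual = config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
        scales["scale_input"] = config["scale_emb"]
        scales["scale_output"] = config["hidden_size"] // config["dim_model_base"]
        scales["scale_o"] = residual
        scales["scale_down"] = residual
    return scales

# Under the old key-presence checks, this Llama-style config would have
# received residual scaling just because it carries "scale_depth"; with the
# model_type gate, all four scales stay at 1.0.
llama_cfg = {"model_type": "llama", "scale_depth": 1.4, "num_hidden_layers": 32}
fm9g_cfg = {
    "model_type": "fm9g",
    "scale_emb": 12,
    "scale_depth": 1.4,
    "num_hidden_layers": 32,
    "hidden_size": 4096,
    "dim_model_base": 256,
}
print(resolve_scales(llama_cfg))  # all scales remain 1.0
print(resolve_scales(fm9g_cfg))   # FM9G scaling applied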