Unverified Commit fcb31c1a authored by Shinichi Hemmi's avatar Shinichi Hemmi Committed by GitHub
Browse files

[Bugfix] Properly initialize `PerTensorScaleParameter` for fused-on-disk checkpoints (#39765)


Signed-off-by: default avatarHemmi Shinichi <shemmi@preferred.jp>
Signed-off-by: default avatarShinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent d886c26d
......@@ -916,8 +916,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight=loaded_weight, shard_id=idx
)
else:
# When weights are already fused on disk (e.g. Phi-3's
# gate_up_proj), there is only a single scale for the
# entire fused matrix. Fill all slots with this scale
# to ensure that any subsequent reduction (like .max())
# works correctly while preserving the parameter shape.
for idx in range(param.data.shape[0]):
param.load_merged_column_weight(
loaded_weight=loaded_weight, shard_id=0
loaded_weight=loaded_weight, shard_id=idx
)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
......@@ -1130,8 +1136,14 @@ class QKVParallelLinear(ColumnParallelLinear):
self.validate_shard_id(loaded_shard_id)
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
# When weights are already fused on disk (e.g. Phi-3's
# qkv_proj), there is only a single scale for the entire
# fused matrix. Fill all slots (q, k, v) with this scale
# to ensure that any subsequent reduction (like .max())
# works correctly while preserving the parameter shape.
for idx in range(param.data.shape[0]):
param.load_qkv_weight(
loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
loaded_weight=loaded_weight, shard_id=idx, tp_rank=self.tp_rank
)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment