Commit fcb31c1a (unverified), authored by Shinichi Hemmi, committed by GitHub
Browse files

[Bugfix] Properly initialize `PerTensorScaleParameter` for fused-on-disk checkpoints (#39765)


Signed-off-by: Hemmi Shinichi <shemmi@preferred.jp>
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent d886c26d
...@@ -916,9 +916,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -916,9 +916,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight=loaded_weight, shard_id=idx loaded_weight=loaded_weight, shard_id=idx
) )
else: else:
param.load_merged_column_weight( # When weights are already fused on disk (e.g. Phi-3's
loaded_weight=loaded_weight, shard_id=0 # gate_up_proj), there is only a single scale for the
) # entire fused matrix. Fill all slots with this scale
# to ensure that any subsequent reduction (like .max())
# works correctly while preserving the parameter shape.
for idx in range(param.data.shape[0]):
param.load_merged_column_weight(
loaded_weight=loaded_weight, shard_id=idx
)
return return
elif type(param) in (RowvLLMParameter, BasevLLMParameter): elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight) param.load_merged_column_weight(loaded_weight=loaded_weight)
...@@ -1130,9 +1136,15 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1130,9 +1136,15 @@ class QKVParallelLinear(ColumnParallelLinear):
self.validate_shard_id(loaded_shard_id) self.validate_shard_id(loaded_shard_id)
if loaded_shard_id is None: # special case for certain models if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter): if isinstance(param, PerTensorScaleParameter):
param.load_qkv_weight( # When weights are already fused on disk (e.g. Phi-3's
loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank # qkv_proj), there is only a single scale for the entire
) # fused matrix. Fill all slots (q, k, v) with this scale
# to ensure that any subsequent reduction (like .max())
# works correctly while preserving the parameter shape.
for idx in range(param.data.shape[0]):
param.load_qkv_weight(
loaded_weight=loaded_weight, shard_id=idx, tp_rank=self.tp_rank
)
return return
elif type(param) in (RowvLLMParameter, BasevLLMParameter): elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank) param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment