Commit 98e1a43a (unverified), authored by kkyyxhll, committed by GitHub
Browse files

[Bugfix][Quantization] Fix PerTensorScale loading with tuple shard_id in...


[Bugfix][Quantization] Fix PerTensorScale loading with tuple shard_id in MergedColumnParallelLinear (#38517)
Signed-off-by: loukang <loukang@xiaohongshu.com>
parent 729eb59f
......@@ -910,7 +910,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.validate_shard_id(loaded_shard_id)
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
if isinstance(loaded_shard_id, tuple):
for idx in loaded_shard_id:
param.load_merged_column_weight(
loaded_weight=loaded_weight, shard_id=idx
)
else:
param.load_merged_column_weight(
loaded_weight=loaded_weight, shard_id=0
)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment