Unverified commit f780504d, authored by Chenhui Zhang, committed by GitHub
Browse files

fix weight loading for GQA with TP (#2379)

parent bfc072ad
......@@ -423,7 +423,10 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = shard_offset // param.pack_factor
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
shard_id = tp_rank // self.num_kv_head_replicas
if loaded_shard_id == "q":
shard_id = tp_rank
else:
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment