Unverified Commit eb825c1e authored by lirui's avatar lirui Committed by GitHub
Browse files

Fix #1474 - AssertionError:assert param_slice.shape == loaded_weight.shape (#1631)

parent 1b290ace
...@@ -250,7 +250,7 @@ class GPTJForCausalLM(nn.Module): ...@@ -250,7 +250,7 @@ class GPTJForCausalLM(nn.Module):
if att_weight_name not in name: if att_weight_name not in name:
continue continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")] param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[1] shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)] (tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size * param_slice = param.data[shard_size * stride_id:shard_size *
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment