Commit 7bb28d54 authored by zhuwenwen's avatar zhuwenwen
Browse files

update linear.py

parent 54ab90dd
...@@ -825,8 +825,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -825,8 +825,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
else: else:
param_data = param_data.narrow(int(not(output_dim)), shard_offset, shard_size) param_data = param_data.narrow(int(not(output_dim)), shard_offset, shard_size)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = self.tp_rank * shard_size start_idx = self.tp_rank * shard_size
if not is_sharded_weight: if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment