Unverified commit c9ba8104, authored by Kevin_Xiong, committed by GitHub
Browse files

[Bugfix] weight loading use correct tp_group with patch_tensor_parallel_group (#21024)


Signed-off-by: KevinXiong-C <kevin_xiong1997@outlook.com>
parent 4e7dfbe7
...@@ -452,8 +452,10 @@ class ColumnParallelLinear(LinearBase): ...@@ -452,8 +452,10 @@ class ColumnParallelLinear(LinearBase):
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.tp_rank = get_tensor_model_parallel_rank()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
is_sharded_weight = getattr(param, "is_sharded_weight", False) is_sharded_weight = getattr(param, "is_sharded_weight", False)
...@@ -472,15 +474,15 @@ class ColumnParallelLinear(LinearBase): ...@@ -472,15 +474,15 @@ class ColumnParallelLinear(LinearBase):
if is_gguf_weight and isinstance(param, UninitializedParameter): if is_gguf_weight and isinstance(param, UninitializedParameter):
final_shape = list(loaded_weight.shape) final_shape = list(loaded_weight.shape)
if output_dim is not None: if output_dim is not None:
tp_size = get_tensor_model_parallel_world_size() assert final_shape[output_dim] % self.tp_size == 0
assert final_shape[output_dim] % tp_size == 0 final_shape[output_dim] = (final_shape[output_dim] //
final_shape[output_dim] = final_shape[output_dim] // tp_size self.tp_size)
param.materialize(final_shape, dtype=loaded_weight.dtype) param.materialize(final_shape, dtype=loaded_weight.dtype)
param_data = param.data param_data = param.data
if output_dim is not None and not is_sharded_weight: if output_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[output_dim] shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
...@@ -565,8 +567,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -565,8 +567,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return_bias: bool = True, return_bias: bool = True,
): ):
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) self.tp_rank = get_tensor_model_parallel_rank()
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
super().__init__(input_size=input_size, super().__init__(input_size=input_size,
output_size=sum(output_sizes), output_size=sum(output_sizes),
bias=bias, bias=bias,
...@@ -598,12 +603,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -598,12 +603,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return return
if is_gguf_weight: if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size shard_size = loaded_weight.size(output_dim) // self.tp_size
start_idx = tp_rank * shard_size start_idx = self.tp_rank * shard_size
if loaded_shard_id is not None: if loaded_shard_id is not None:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
...@@ -669,11 +672,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -669,11 +672,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return return
assert loaded_shard_id < len(self.output_sizes) assert loaded_shard_id < len(self.output_sizes)
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
if output_dim is not None: if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
shard_size = self.output_sizes[loaded_shard_id] // tp_size self.tp_size)
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
# Special case for quantization. # Special case for quantization.
# If quantized, we need to adjust the offset and size to account # If quantized, we need to adjust the offset and size to account
# for the packing. # for the packing.
...@@ -701,7 +703,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -701,7 +703,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
start_idx = tp_rank * shard_size start_idx = self.tp_rank * shard_size
if not is_sharded_weight: if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
...@@ -991,12 +993,9 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -991,12 +993,9 @@ class QKVParallelLinear(ColumnParallelLinear):
return return
if is_gguf_weight: if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size shard_size = loaded_weight.size(output_dim) // self.tp_size
start_idx = tp_rank * shard_size start_idx = self.tp_rank * shard_size
if loaded_shard_id is not None: if loaded_shard_id is not None:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
...@@ -1071,7 +1070,6 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1071,7 +1070,6 @@ class QKVParallelLinear(ColumnParallelLinear):
self.weight_loader(param, loaded_weight_shard, shard_id) self.weight_loader(param, loaded_weight_shard, shard_id)
return return
tp_rank = get_tensor_model_parallel_rank()
assert loaded_shard_id in ["q", "k", "v"] assert loaded_shard_id in ["q", "k", "v"]
# If output dim is defined, use the default loading process. # If output dim is defined, use the default loading process.
...@@ -1123,9 +1121,9 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1123,9 +1121,9 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q":
shard_id = tp_rank shard_id = self.tp_rank
else: else:
shard_id = tp_rank // self.num_kv_head_replicas shard_id = self.tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size start_idx = shard_id * shard_size
if not is_sharded_weight: if not is_sharded_weight:
...@@ -1245,8 +1243,6 @@ class RowParallelLinear(LinearBase): ...@@ -1245,8 +1243,6 @@ class RowParallelLinear(LinearBase):
self.register_parameter("bias", None) self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False) is_sharded_weight = getattr(param, "is_sharded_weight", False)
...@@ -1264,13 +1260,14 @@ class RowParallelLinear(LinearBase): ...@@ -1264,13 +1260,14 @@ class RowParallelLinear(LinearBase):
if is_gguf_weight and isinstance(param, UninitializedParameter): if is_gguf_weight and isinstance(param, UninitializedParameter):
weight_shape = list(loaded_weight.shape) weight_shape = list(loaded_weight.shape)
if input_dim: if input_dim:
weight_shape[input_dim] = weight_shape[input_dim] // tp_size weight_shape[input_dim] = (weight_shape[input_dim] //
self.tp_size)
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
param_data = param.data param_data = param.data
if input_dim is not None and not is_sharded_weight: if input_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[input_dim] shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx, loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size) shard_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment