Unverified Commit 2b25b7d2 authored by Szymon Ożóg's avatar Szymon Ożóg Committed by GitHub
Browse files

Fix initializing GGUF weights for ColumnParallelLinear when using tensor parallel > 1 (#13023)

parent 6c4dbe23
......@@ -335,6 +335,12 @@ class ColumnParallelLinear(LinearBase):
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
......@@ -343,13 +349,12 @@ class ColumnParallelLinear(LinearBase):
# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
final_shape = list(loaded_weight.shape)
if output_dim is not None:
tp_size = get_tensor_model_parallel_world_size()
assert final_shape[output_dim] % tp_size == 0
final_shape[output_dim] = final_shape[output_dim] // tp_size
param.materialize(final_shape, dtype=loaded_weight.dtype)
param_data = param.data
if output_dim is not None and not is_sharded_weight:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment