Unverified commit 9a9d4424, authored by xjx and committed by GitHub
Browse files

Enable bnb for multiple indices weight (#35838)


Signed-off-by: xjx <493337577@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent f7da9cdf
......@@ -744,10 +744,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
)
current_shard_offset = 0
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple):
if (
use_bitsandbytes_4bit
and isinstance(loaded_shard_id, tuple)
and self.tp_size > 1
):
raise NotImplementedError(
"Shard id with multiple indices is not supported "
"for BNB quantization yet."
"for BNB quantization with TP yet."
)
shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(output_sizes):
......@@ -815,9 +819,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
index = list(itertools.accumulate([0] + self.output_sizes))
orig_offsets = {
str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
}
orig_offsets["total"] = (self.output_size, 0)
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_offsets, str(loaded_shard_id)
)
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
start_idx = self.tp_rank * shard_size
if not is_sharded_weight:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment