Unverified Commit 7eb719df authored by Jee Jee Li, committed by GitHub
Browse files

[Bugfix] Fix Phi-3 BNB online quantization (#10417)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 284203f1
...@@ -470,7 +470,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -470,7 +470,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None: if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv/mlp). # Loaded weight is already fused on disk (mlp).
# (e.g., Phi-3's gate_up_proj).
if output_dim is None: if output_dim is None:
if needs_scalar_to_array: if needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array( param_data, loaded_weight = adjust_scalar_to_fused_array(
...@@ -480,6 +481,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -480,6 +481,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
return return
current_shard_offset = 0 current_shard_offset = 0
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
shard_offsets: List[Tuple[int, int, int]] = [] shard_offsets: List[Tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes): for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size)) shard_offsets.append((i, current_shard_offset, output_size))
...@@ -495,7 +498,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -495,7 +498,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin. # Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim] // 2
shard_offset = shard_size * shard_id
loaded_weight_shard = loaded_weight.narrow( loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size) output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id) self.weight_loader(param, loaded_weight_shard, shard_id)
...@@ -808,7 +813,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -808,7 +813,8 @@ class QKVParallelLinear(ColumnParallelLinear):
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None: if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv/mlp). # Loaded weight is already fused on disk (qkv).
# (e.g., Phi-3's qkv_proj).
if output_dim is None: if output_dim is None:
if needs_scalar_to_array: if needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array( param_data, loaded_weight = adjust_scalar_to_fused_array(
......
...@@ -14,3 +14,13 @@ class Phi3ForCausalLM(LlamaForCausalLM): ...@@ -14,3 +14,13 @@ class Phi3ForCausalLM(LlamaForCausalLM):
"gate_up_proj", "gate_up_proj",
], ],
} }
# BitsAndBytes-specific attributes
default_bitsandbytes_target_modules = [
".gate_up_proj.",
".down_proj.",
".qkv_proj.",
".o_proj.",
]
# Initialize an empty dict when there is no stacked parameter mapping.
bitsandbytes_stacked_params_mapping = {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment