Commit 33f37e9f authored by zhuwenwen's avatar zhuwenwen
Browse files

update the layout of weight_loader_v2 in ColumnParallelLinear and update the...

update the layout of weight_loader_v2 in ColumnParallelLinear and update the initialization of is_quantification
parent e35a9e99
......@@ -413,6 +413,7 @@ class ReplicatedLinear(LinearBase):
})
else:
self.register_parameter("bias", None)
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# If the weight on disk does not have a shape, give it one
......@@ -431,8 +432,7 @@ class ReplicatedLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if envs.VLLM_USE_NN and not is_quantization:
if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t()
assert param.size() == loaded_weight.size(), (
......@@ -592,6 +592,7 @@ class ColumnParallelLinear(LinearBase):
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
......@@ -602,7 +603,6 @@ class ColumnParallelLinear(LinearBase):
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
......@@ -621,7 +621,7 @@ class ColumnParallelLinear(LinearBase):
param_data = param.data
if output_dim is not None and not is_sharded_weight:
if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization:
if not envs.VLLM_USE_NN or len(param_data.shape)==1 or self.is_quantization:
shard_size = param_data.shape[output_dim]
else:
shard_size = param_data.shape[int(not(output_dim))]
......@@ -634,7 +634,7 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if envs.VLLM_USE_NN and not is_quantization:
if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape
......@@ -647,7 +647,7 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight)
param.load_column_parallel_weight(loaded_weight=loaded_weight if not envs.VLLM_USE_NN or self.is_quantization else loaded_weight.t())
def forward(
self,
......@@ -812,6 +812,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def weight_loader(self,
param: Parameter,
......@@ -851,7 +852,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_dim = getattr(param, "output_dim", None)
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (mlp).
......@@ -931,7 +931,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
if not envs.VLLM_USE_NN or is_quantization:
if not envs.VLLM_USE_NN or self.is_quantization:
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
else:
param_data = param_data.narrow(int(not(output_dim)), shard_offset, shard_size)
......@@ -953,7 +953,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
if envs.VLLM_USE_NN and not is_quantization:
if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape
......@@ -996,7 +996,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter):
......@@ -1041,7 +1040,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
is_quantization=is_quantization)
is_quantization=self.is_quantization)
class QKVParallelLinear(ColumnParallelLinear):
......@@ -1123,6 +1122,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
......@@ -1181,7 +1181,6 @@ class QKVParallelLinear(ColumnParallelLinear):
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
......@@ -1227,7 +1226,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
is_quantization=is_quantization)
is_quantization=self.is_quantization)
def weight_loader(self,
param: Parameter,
......@@ -1268,7 +1267,6 @@ class QKVParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv).
......@@ -1375,7 +1373,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id)
if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization:
if not envs.VLLM_USE_NN or len(param_data.shape)==1 or self.is_quantization:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
......@@ -1404,7 +1402,7 @@ class QKVParallelLinear(ColumnParallelLinear):
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
if envs.VLLM_USE_NN and not is_quantization:
if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape
......@@ -1507,6 +1505,7 @@ class RowParallelLinear(LinearBase):
self.register_parameter("bias", None)
self.update_param_tp_status()
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
input_dim = getattr(param, "input_dim", None)
......@@ -1530,11 +1529,9 @@ class RowParallelLinear(LinearBase):
self.tp_size)
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
param_data = param.data
if input_dim is not None and not is_sharded_weight:
if not envs.VLLM_USE_NN or is_quantization:
if not envs.VLLM_USE_NN or self.is_quantization:
shard_size = param_data.shape[input_dim]
else:
shard_size = param_data.shape[int(not(input_dim))]
......@@ -1547,7 +1544,7 @@ class RowParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if envs.VLLM_USE_NN and not is_quantization:
if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape
......@@ -1562,7 +1559,7 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_row_parallel_weight(loaded_weight=loaded_weight.t() if envs.VLLM_USE_NN else loaded_weight)
param.load_row_parallel_weight(loaded_weight=loaded_weight if not envs.VLLM_USE_NN or self.is_quantization else loaded_weight.t())
def forward(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment