Commit 33f37e9f authored by zhuwenwen
Browse files

update the layout of weight_loader_v2 in ColumnParallelLinear and update the...

Update the layout of weight_loader_v2 in ColumnParallelLinear and update the initialization of is_quantization
parent e35a9e99
...@@ -413,6 +413,7 @@ class ReplicatedLinear(LinearBase): ...@@ -413,6 +413,7 @@ class ReplicatedLinear(LinearBase):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# If the weight on disk does not have a shape, give it one # If the weight on disk does not have a shape, give it one
...@@ -431,8 +432,7 @@ class ReplicatedLinear(LinearBase): ...@@ -431,8 +432,7 @@ class ReplicatedLinear(LinearBase):
if len(loaded_weight.shape) == 0: if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod) if envs.VLLM_USE_NN and not self.is_quantization:
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
assert param.size() == loaded_weight.size(), ( assert param.size() == loaded_weight.size(), (
...@@ -592,6 +592,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -592,6 +592,7 @@ class ColumnParallelLinear(LinearBase):
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.update_param_tp_status() self.update_param_tp_status()
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
...@@ -602,7 +603,6 @@ class ColumnParallelLinear(LinearBase): ...@@ -602,7 +603,6 @@ class ColumnParallelLinear(LinearBase):
# bitsandbytes loads the weights of the specific portion # bitsandbytes loads the weights of the specific portion
# no need to narrow # no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
# Special case for GGUF # Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight = getattr(param, "is_gguf_weight", False)
...@@ -621,7 +621,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -621,7 +621,7 @@ class ColumnParallelLinear(LinearBase):
param_data = param.data param_data = param.data
if output_dim is not None and not is_sharded_weight: if output_dim is not None and not is_sharded_weight:
if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization: if not envs.VLLM_USE_NN or len(param_data.shape)==1 or self.is_quantization:
shard_size = param_data.shape[output_dim] shard_size = param_data.shape[output_dim]
else: else:
shard_size = param_data.shape[int(not(output_dim))] shard_size = param_data.shape[int(not(output_dim))]
...@@ -634,7 +634,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -634,7 +634,7 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0: if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
if envs.VLLM_USE_NN and not is_quantization: if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
...@@ -647,7 +647,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -647,7 +647,7 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0: if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1 assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight) param.load_column_parallel_weight(loaded_weight=loaded_weight if not envs.VLLM_USE_NN or self.is_quantization else loaded_weight.t())
def forward( def forward(
self, self,
...@@ -812,6 +812,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -812,6 +812,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix=prefix, prefix=prefix,
return_bias=return_bias, return_bias=return_bias,
disable_tp=disable_tp) disable_tp=disable_tp)
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -851,7 +852,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -851,7 +852,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
# Special case for per-tensor scale to load scalar into fused array. # Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None: if loaded_shard_id is None:
# Loaded weight is already fused on disk (mlp). # Loaded weight is already fused on disk (mlp).
...@@ -931,7 +931,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -931,7 +931,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = loaded_weight.shape[output_dim] * \ shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id loaded_shard_id
if not envs.VLLM_USE_NN or is_quantization: if not envs.VLLM_USE_NN or self.is_quantization:
param_data = param_data.narrow(output_dim, shard_offset, shard_size) param_data = param_data.narrow(output_dim, shard_offset, shard_size)
else: else:
param_data = param_data.narrow(int(not(output_dim)), shard_offset, shard_size) param_data = param_data.narrow(int(not(output_dim)), shard_offset, shard_size)
...@@ -953,7 +953,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -953,7 +953,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"MergedColumnParallelLinear, assume the weight is " "MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.") "the same for all partitions.")
if envs.VLLM_USE_NN and not is_quantization: if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
...@@ -996,7 +996,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -996,7 +996,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param: BasevLLMParameter, param: BasevLLMParameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None): loaded_shard_id: Optional[int] = None):
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None: if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter): if isinstance(param, PerTensorScaleParameter):
...@@ -1041,7 +1040,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -1041,7 +1040,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset=shard_offset, shard_offset=shard_offset,
shard_size=shard_size, shard_size=shard_size,
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
is_quantization=is_quantization) is_quantization=self.is_quantization)
class QKVParallelLinear(ColumnParallelLinear): class QKVParallelLinear(ColumnParallelLinear):
...@@ -1123,6 +1122,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1123,6 +1122,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix=prefix, prefix=prefix,
return_bias=return_bias, return_bias=return_bias,
disable_tp=disable_tp) disable_tp=disable_tp)
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def _get_shard_offset_mapping(self, loaded_shard_id: str): def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = { shard_offset_mapping = {
...@@ -1181,7 +1181,6 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1181,7 +1181,6 @@ class QKVParallelLinear(ColumnParallelLinear):
param: BasevLLMParameter, param: BasevLLMParameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None): loaded_shard_id: Optional[str] = None):
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None: # special case for certain models if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter): if isinstance(param, PerTensorScaleParameter):
...@@ -1227,7 +1226,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1227,7 +1226,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset=shard_offset, shard_offset=shard_offset,
shard_size=shard_size, shard_size=shard_size,
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
is_quantization=is_quantization) is_quantization=self.is_quantization)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -1268,7 +1267,6 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1268,7 +1267,6 @@ class QKVParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scales in fused case. # Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None: if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv). # Loaded weight is already fused on disk (qkv).
...@@ -1375,7 +1373,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1375,7 +1373,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id) param, orig_qkv_offsets, loaded_shard_id)
if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization: if not envs.VLLM_USE_NN or len(param_data.shape)==1 or self.is_quantization:
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
else: else:
...@@ -1404,7 +1402,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1404,7 +1402,7 @@ class QKVParallelLinear(ColumnParallelLinear):
"QKVParallelLinear, assume the weight is the same " "QKVParallelLinear, assume the weight is the same "
"for all partitions.") "for all partitions.")
if envs.VLLM_USE_NN and not is_quantization: if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
...@@ -1507,6 +1505,7 @@ class RowParallelLinear(LinearBase): ...@@ -1507,6 +1505,7 @@ class RowParallelLinear(LinearBase):
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.update_param_tp_status() self.update_param_tp_status()
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
...@@ -1529,12 +1528,10 @@ class RowParallelLinear(LinearBase): ...@@ -1529,12 +1528,10 @@ class RowParallelLinear(LinearBase):
weight_shape[input_dim] = (weight_shape[input_dim] // weight_shape[input_dim] = (weight_shape[input_dim] //
self.tp_size) self.tp_size)
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
param_data = param.data param_data = param.data
if input_dim is not None and not is_sharded_weight: if input_dim is not None and not is_sharded_weight:
if not envs.VLLM_USE_NN or is_quantization: if not envs.VLLM_USE_NN or self.is_quantization:
shard_size = param_data.shape[input_dim] shard_size = param_data.shape[input_dim]
else: else:
shard_size = param_data.shape[int(not(input_dim))] shard_size = param_data.shape[int(not(input_dim))]
...@@ -1547,7 +1544,7 @@ class RowParallelLinear(LinearBase): ...@@ -1547,7 +1544,7 @@ class RowParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0: if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
if envs.VLLM_USE_NN and not is_quantization: if envs.VLLM_USE_NN and not self.is_quantization:
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
...@@ -1562,7 +1559,7 @@ class RowParallelLinear(LinearBase): ...@@ -1562,7 +1559,7 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1 assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
param.load_row_parallel_weight(loaded_weight=loaded_weight.t() if envs.VLLM_USE_NN else loaded_weight) param.load_row_parallel_weight(loaded_weight=loaded_weight if not envs.VLLM_USE_NN or self.is_quantization else loaded_weight.t())
def forward( def forward(
self, self,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment