"vscode:/vscode.git/clone" did not exist on "584f5fb4c6d96365a3bfa8594115bc02744f2096"
Unverified commit 5c6c54d6, authored by Dipika Sikka and committed by GitHub
Browse files

[Bugfix] Fix `PerTensorScaleParameter` weight loading for fused models (#7376)

parent 933790c2
...@@ -14,7 +14,8 @@ from vllm.logger import init_logger ...@@ -14,7 +14,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter) PackedvLLMParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -573,11 +574,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -573,11 +574,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param: BasevLLMParameter, param: BasevLLMParameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None): loaded_shard_id: Optional[int] = None):
param_data = param.data
if loaded_shard_id is None: if loaded_shard_id is None:
if param.output_dim is None: if isinstance(param, PerTensorScaleParameter):
assert param_data.shape == loaded_weight.shape param.load_merged_column_weight(loaded_weight=loaded_weight,
param_data.copy_(loaded_weight) shard_id=0)
return
elif type(param) is BasevLLMParameter:
param.load_merged_column_weight(loaded_weight=loaded_weight)
return return
self._load_fused_module_from_checkpoint(param, loaded_weight) self._load_fused_module_from_checkpoint(param, loaded_weight)
return return
...@@ -720,11 +723,13 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -720,11 +723,13 @@ class QKVParallelLinear(ColumnParallelLinear):
param: BasevLLMParameter, param: BasevLLMParameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None): loaded_shard_id: Optional[str] = None):
param_data = param.data
if loaded_shard_id is None: # special case for certain models if loaded_shard_id is None: # special case for certain models
if param.output_dim is None: if isinstance(param, PerTensorScaleParameter):
assert param_data.shape == loaded_weight.shape param.load_merged_column_weight(loaded_weight=loaded_weight,
param_data.copy_(loaded_weight) shard_id=0)
return
elif type(param) is BasevLLMParameter:
param.load_merged_column_weight(loaded_weight=loaded_weight)
return return
self._load_fused_module_from_checkpoint(param, loaded_weight) self._load_fused_module_from_checkpoint(param, loaded_weight)
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment