Commit 6605af8e authored by zhuwenwen's avatar zhuwenwen
Browse files

update weight_loader_v2 layout of ColumnParallelLinear and MergedColumnParallelLinear

parent e8700643
...@@ -989,6 +989,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -989,6 +989,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param: BasevLLMParameter, param: BasevLLMParameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None): loaded_shard_id: Optional[int] = None):
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None: if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter): if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, param.load_merged_column_weight(loaded_weight=loaded_weight,
...@@ -1020,11 +1022,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -1020,11 +1022,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.output_sizes[:loaded_shard_id]) // self.tp_size self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
param.load_merged_column_weight(loaded_weight=loaded_weight, if not envs.VLLM_USE_NN:
shard_id=loaded_shard_id, param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_offset=shard_offset, shard_id=loaded_shard_id,
shard_size=shard_size, shard_offset=shard_offset,
tp_rank=self.tp_rank) shard_size=shard_size,
tp_rank=self.tp_rank)
else:
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
is_quantization=is_quantization)
class QKVParallelLinear(ColumnParallelLinear): class QKVParallelLinear(ColumnParallelLinear):
...@@ -1164,6 +1174,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1164,6 +1174,8 @@ class QKVParallelLinear(ColumnParallelLinear):
param: BasevLLMParameter, param: BasevLLMParameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None): loaded_shard_id: Optional[str] = None):
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None: # special case for certain models if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter): if isinstance(param, PerTensorScaleParameter):
param.load_qkv_weight(loaded_weight=loaded_weight, param.load_qkv_weight(loaded_weight=loaded_weight,
...@@ -1194,12 +1206,21 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1194,12 +1206,21 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = (shard_offset + block_n - 1) // block_n shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n shard_size = (shard_size + block_n - 1) // block_n
param.load_qkv_weight(loaded_weight=loaded_weight, if not envs.VLLM_USE_NN:
num_heads=self.num_kv_head_replicas, param.load_qkv_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id, num_heads=self.num_kv_head_replicas,
shard_offset=shard_offset, shard_id=loaded_shard_id,
shard_size=shard_size, shard_offset=shard_offset,
tp_rank=self.tp_rank) shard_size=shard_size,
tp_rank=self.tp_rank)
else:
param.load_qkv_weight(loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
is_quantization=is_quantization)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -1534,7 +1555,7 @@ class RowParallelLinear(LinearBase): ...@@ -1534,7 +1555,7 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1 assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
param.load_row_parallel_weight(loaded_weight=loaded_weight) param.load_row_parallel_weight(loaded_weight=loaded_weight.t() if envs.VLLM_USE_NN else loaded_weight)
def forward( def forward(
self, self,
......
...@@ -8,6 +8,8 @@ from weakref import WeakValueDictionary ...@@ -8,6 +8,8 @@ from weakref import WeakValueDictionary
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
import vllm.envs as envs
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
...@@ -150,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -150,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_offset = kwargs.get("shard_offset") shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size") shard_size = kwargs.get("shard_size")
is_quantization = kwargs.get("is_quantization")
# TODO: move these to PackedColumnParameter and PackedvLLMParameter # TODO: move these to PackedColumnParameter and PackedvLLMParameter
if isinstance( if isinstance(
...@@ -161,11 +164,19 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -161,11 +164,19 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data = self.data param_data = self.data
param_data = param_data.narrow(self.output_dim, shard_offset, if not envs.VLLM_USE_NN or is_quantization:
shard_size) param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(int(not(self.output_dim)), shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim, loaded_weight = loaded_weight.narrow(self.output_dim,
self.tp_rank * shard_size, self.tp_rank * shard_size,
shard_size) shard_size)
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
...@@ -175,6 +186,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -175,6 +186,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_size = kwargs.get("shard_size") shard_size = kwargs.get("shard_size")
shard_id = kwargs.get("shard_id") shard_id = kwargs.get("shard_id")
num_heads = kwargs.get("num_heads") num_heads = kwargs.get("num_heads")
is_quantization = kwargs.get("is_quantization")
# TODO: move these to PackedColumnParameter and PackedvLLMParameter # TODO: move these to PackedColumnParameter and PackedvLLMParameter
if isinstance( if isinstance(
...@@ -187,11 +199,18 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -187,11 +199,18 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data = self.data param_data = self.data
shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank // shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank //
num_heads) num_heads)
param_data = param_data.narrow(self.output_dim, shard_offset, if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization:
shard_size) param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(int(not(self.output_dim)), shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim, loaded_weight = loaded_weight.narrow(self.output_dim,
shard_id * shard_size, shard_size) shard_id * shard_size, shard_size)
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment