Commit f3731273 authored by zhuwenwen
Browse files

update the layout of load_column_parallel_weight and load_row_parallel_weight

parent 5db8533c
......@@ -647,7 +647,7 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight if not envs.VLLM_USE_NN or self.is_quantization else loaded_weight.t())
param.load_column_parallel_weight(loaded_weight=loaded_weight, is_quantization=self.is_quantization)
def forward(
self,
......@@ -835,7 +835,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return
if is_gguf_weight:
output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // self.tp_size
start_idx = self.tp_rank * shard_size
......@@ -1027,20 +1026,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = sum(
self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
if not envs.VLLM_USE_NN:
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank)
else:
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
is_quantization=self.is_quantization)
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
is_quantization=self.is_quantization)
class QKVParallelLinear(ColumnParallelLinear):
......@@ -1212,21 +1204,13 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n
if not envs.VLLM_USE_NN:
param.load_qkv_weight(loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank)
else:
param.load_qkv_weight(loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
is_quantization=self.is_quantization)
param.load_qkv_weight(loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
is_quantization=self.is_quantization)
def weight_loader(self,
param: Parameter,
......@@ -1559,7 +1543,7 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_row_parallel_weight(loaded_weight=loaded_weight if not envs.VLLM_USE_NN or self.is_quantization else loaded_weight.t())
param.load_row_parallel_weight(loaded_weight=loaded_weight, is_quantization=self.is_quantization)
def forward(
self,
......
......@@ -140,11 +140,18 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def output_dim(self):
return self._output_dim
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
shard_size = self.data.shape[self.output_dim]
def load_column_parallel_weight(self,
                                loaded_weight: torch.Tensor,
                                is_quantization: Optional[bool] = None):
    """Copy this rank's shard of ``loaded_weight`` into ``self.data``.

    ``loaded_weight`` is narrowed along ``self.output_dim`` to the slice
    that starts at ``self.tp_rank * shard_size``. When
    ``envs.VLLM_USE_NN`` is set and the layer is not quantized, the
    narrowed slice is transposed before it is copied, so ``self.data`` is
    expected to hold the transpose of the loaded shard.

    Args:
        loaded_weight: Tensor to shard; it is sliced along
            ``self.output_dim``.
        is_quantization: Truthy when the owning layer is quantized, in
            which case no transpose is applied. Defaults to ``None``
            (handled like ``False``) so callers that predate this
            argument keep working.

    Raises:
        AssertionError: If the prepared shard's shape differs from
            ``self.data.shape``.
    """
    # One flag drives both the shard-size lookup and the transpose so the
    # two decisions cannot drift apart. 1-D parameters have no second axis
    # to swap, so they always take the plain path.
    transposed = bool(envs.VLLM_USE_NN and not is_quantization
                      and len(self.data.shape) != 1)

    if transposed:
        # ``int(not dim)`` maps 0 -> 1 and any non-zero value -> 0, i.e.
        # it selects the other axis of a 2-D parameter.
        shard_size = self.data.shape[int(not self.output_dim)]
    else:
        shard_size = self.data.shape[self.output_dim]

    # The slice is always taken along ``output_dim`` of the incoming
    # tensor, before any transpose.
    loaded_weight = loaded_weight.narrow(self.output_dim,
                                         self.tp_rank * shard_size,
                                         shard_size)
    if transposed:
        loaded_weight = loaded_weight.t()

    assert self.data.shape == loaded_weight.shape, (
        f"shape mismatch: parameter {tuple(self.data.shape)} vs "
        f"loaded shard {tuple(loaded_weight.shape)}")
    self.data.copy_(loaded_weight)
......@@ -231,8 +238,11 @@ class RowvLLMParameter(BasevLLMParameter):
def input_dim(self):
return self._input_dim
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
shard_size = self.data.shape[self.input_dim]
def load_row_parallel_weight(self,
                             loaded_weight: torch.Tensor,
                             is_quantization: Optional[bool] = None):
    """Copy this rank's shard of ``loaded_weight`` into ``self.data``.

    ``loaded_weight`` is narrowed along ``self.input_dim`` to the slice
    that starts at ``self.tp_rank * shard_size``. When
    ``envs.VLLM_USE_NN`` is set and the layer is not quantized, the
    narrowed slice is transposed before it is copied, so ``self.data`` is
    expected to hold the transpose of the loaded shard.

    Args:
        loaded_weight: Tensor to shard; it is sliced along
            ``self.input_dim``.
        is_quantization: Truthy when the owning layer is quantized, in
            which case no transpose is applied. Defaults to ``None``
            (handled like ``False``) so callers that predate this
            argument keep working.

    Raises:
        AssertionError: If the prepared shard's shape differs from
            ``self.data.shape``.
    """
    # One flag drives both the shard-size lookup and the transpose.
    # 1-D parameters have no second axis to swap, so they take the plain
    # path (this mirrors the guard in load_column_parallel_weight).
    transposed = bool(envs.VLLM_USE_NN and not is_quantization
                      and len(self.data.shape) != 1)

    if transposed:
        # ``int(not dim)`` maps 0 -> 1 and any non-zero value -> 0, i.e.
        # it selects the other axis of a 2-D parameter.
        shard_size = self.data.shape[int(not self.input_dim)]
    else:
        shard_size = self.data.shape[self.input_dim]

    # The slice is always taken along ``input_dim`` of the incoming
    # tensor, before any transpose.
    loaded_weight = loaded_weight.narrow(self.input_dim,
                                         self.tp_rank * shard_size,
                                         shard_size)

    # NOTE(review): narrow() runs before this check, so a 0-dim tensor
    # presumably never gets this far -- confirm whether this branch is
    # still reachable.
    if len(loaded_weight.shape) == 0:
        loaded_weight = loaded_weight.reshape(1)

    if transposed:
        loaded_weight = loaded_weight.t()

    assert self.data.shape == loaded_weight.shape, (
        f"shape mismatch: parameter {tuple(self.data.shape)} vs "
        f"loaded shard {tuple(loaded_weight.shape)}")
    self.data.copy_(loaded_weight)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment