Commit a9c37628 authored by zhuwenwen's avatar zhuwenwen
Browse files

init is_quantization of load_column_parallel_weight (BasevLLMParameter)

parent 3efb2e1c
...@@ -91,10 +91,10 @@ class BasevLLMParameter(Parameter): ...@@ -91,10 +91,10 @@ class BasevLLMParameter(Parameter):
or self._is_1d_and_scalar(loaded_weight)) or self._is_1d_and_scalar(loaded_weight))
self.data.copy_(loaded_weight) self.data.copy_(loaded_weight)
def load_column_parallel_weight(self, loaded_weight: torch.Tensor): def load_column_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False):
self._assert_and_load(loaded_weight) self._assert_and_load(loaded_weight)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = None): def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False):
self._assert_and_load(loaded_weight) self._assert_and_load(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
...@@ -140,7 +140,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -140,7 +140,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def output_dim(self): def output_dim(self):
return self._output_dim return self._output_dim
def load_column_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization:Optional[bool]): def load_column_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False):
if not envs.VLLM_USE_NN or len( self.data.shape)==1 or is_quantization: if not envs.VLLM_USE_NN or len( self.data.shape)==1 or is_quantization:
shard_size = self.data.shape[self.output_dim] shard_size = self.data.shape[self.output_dim]
else: else:
...@@ -238,7 +238,7 @@ class RowvLLMParameter(BasevLLMParameter): ...@@ -238,7 +238,7 @@ class RowvLLMParameter(BasevLLMParameter):
def input_dim(self): def input_dim(self):
return self._input_dim return self._input_dim
def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = None): def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False):
if not envs.VLLM_USE_NN or is_quantization: if not envs.VLLM_USE_NN or is_quantization:
shard_size = self.data.shape[self.input_dim] shard_size = self.data.shape[self.input_dim]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment