Commit a9c37628 authored by zhuwenwen
Browse files

Initialize is_quantization with a default of False in load_column_parallel_weight and load_row_parallel_weight (BasevLLMParameter and its column/row subclasses)

parent 3efb2e1c
......@@ -91,10 +91,10 @@ class BasevLLMParameter(Parameter):
or self._is_1d_and_scalar(loaded_weight))
self.data.copy_(loaded_weight)
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
def load_column_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False):
self._assert_and_load(loaded_weight)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = None):
def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False):
self._assert_and_load(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
......@@ -140,7 +140,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def output_dim(self):
return self._output_dim
def load_column_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization:Optional[bool]):
def load_column_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False):
if not envs.VLLM_USE_NN or len( self.data.shape)==1 or is_quantization:
shard_size = self.data.shape[self.output_dim]
else:
......@@ -238,7 +238,7 @@ class RowvLLMParameter(BasevLLMParameter):
def input_dim(self):
return self._input_dim
def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = None):
def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False):
if not envs.VLLM_USE_NN or is_quantization:
shard_size = self.data.shape[self.input_dim]
else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment