Commit 3efb2e1c authored by zhuwenwen
Browse files

Add an optional is_quantization parameter (default None) to load_row_parallel_weight in BasevLLMParameter, and give it the same default in RowvLLMParameter

parent 39380c86
...@@ -94,7 +94,7 @@ class BasevLLMParameter(Parameter): ...@@ -94,7 +94,7 @@ class BasevLLMParameter(Parameter):
def load_column_parallel_weight(self, loaded_weight: torch.Tensor): def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
self._assert_and_load(loaded_weight) self._assert_and_load(loaded_weight)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor): def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = None):
self._assert_and_load(loaded_weight) self._assert_and_load(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
...@@ -238,7 +238,7 @@ class RowvLLMParameter(BasevLLMParameter): ...@@ -238,7 +238,7 @@ class RowvLLMParameter(BasevLLMParameter):
def input_dim(self): def input_dim(self):
return self._input_dim return self._input_dim
def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization:Optional[bool]): def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = None):
if not envs.VLLM_USE_NN or is_quantization: if not envs.VLLM_USE_NN or is_quantization:
shard_size = self.data.shape[self.input_dim] shard_size = self.data.shape[self.input_dim]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment