Unverified commit 52d42829, authored by maang, committed by GitHub
Browse files

[Core] Refactor ColumnParallelLinear: remove unused parameter and optimize forward (#31939)


Signed-off-by: maang <maang_h@163.com>
parent c60578de
...@@ -411,10 +411,10 @@ class ReplicatedLinear(LinearBase): ...@@ -411,10 +411,10 @@ class ReplicatedLinear(LinearBase):
assert self.quant_method is not None assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias) output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias: if not self.return_bias:
return output return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
...@@ -444,8 +444,6 @@ class ColumnParallelLinear(LinearBase): ...@@ -444,8 +444,6 @@ class ColumnParallelLinear(LinearBase):
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
quant_config: Quantization configure. quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj) (e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass. return_bias: If true, return bias together with outputs in forward pass.
...@@ -463,7 +461,6 @@ class ColumnParallelLinear(LinearBase): ...@@ -463,7 +461,6 @@ class ColumnParallelLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None, params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
output_sizes: list[int] | None = None,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
...@@ -495,9 +492,6 @@ class ColumnParallelLinear(LinearBase): ...@@ -495,9 +492,6 @@ class ColumnParallelLinear(LinearBase):
self._maybe_allow_fp8_block_shape_mismatch() self._maybe_allow_fp8_block_shape_mismatch()
self.gather_output = gather_output self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
assert self.quant_method is not None assert self.quant_method is not None
self.quant_method.create_weights( self.quant_method.create_weights(
layer=self, layer=self,
...@@ -614,9 +608,10 @@ class ColumnParallelLinear(LinearBase): ...@@ -614,9 +608,10 @@ class ColumnParallelLinear(LinearBase):
output = tensor_model_parallel_all_gather(output_parallel) output = tensor_model_parallel_all_gather(output_parallel)
else: else:
output = output_parallel output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias: if not self.return_bias:
return output return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
...@@ -1469,10 +1464,9 @@ class RowParallelLinear(LinearBase): ...@@ -1469,10 +1464,9 @@ class RowParallelLinear(LinearBase):
else: else:
output = output_parallel output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias: if not self.return_bias:
return output return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment