Unverified Commit 52d42829 authored by maang's avatar maang Committed by GitHub
Browse files

[Core] Refactor ColumnParallelLinear: remove unused parameter and optimize forward (#31939)


Signed-off-by: default avatarmaang <maang_h@163.com>
parent c60578de
......@@ -411,10 +411,10 @@ class ReplicatedLinear(LinearBase):
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
......@@ -444,8 +444,6 @@ class ColumnParallelLinear(LinearBase):
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
......@@ -463,7 +461,6 @@ class ColumnParallelLinear(LinearBase):
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
output_sizes: list[int] | None = None,
prefix: str = "",
*,
return_bias: bool = True,
......@@ -495,9 +492,6 @@ class ColumnParallelLinear(LinearBase):
self._maybe_allow_fp8_block_shape_mismatch()
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
......@@ -614,9 +608,10 @@ class ColumnParallelLinear(LinearBase):
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
......@@ -1469,10 +1464,9 @@ class RowParallelLinear(LinearBase):
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment