Commit 56ebbba3 authored by zhuwenwen's avatar zhuwenwen
Browse files

update linear of RowParallelLinear and UnquantizedLinearMethod apply

parent ee346e93
......@@ -250,7 +250,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
output: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn:
if gemm_bank_conf(layer.weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1':
layer.weight = layer.weight[:,:-32]
......@@ -265,8 +267,42 @@ class UnquantizedLinearMethod(LinearMethodBase):
else:
if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)
else:
weight = layer.weight
if residual is not None:
assert output is None or output is residual
if get_tensor_model_parallel_world_size(
) > 1 and get_tensor_model_parallel_rank() != 0:
beta = 0.0
else:
beta = 1.0
# optimize cuda memory usage
if x.dim() == 2:
torch.addmm(residual, x, weight.t(), beta=beta, out=residual)
elif x.dim() >= 3:
hx = x.size(-1)
hr = residual.size(-1)
torch.addmm(residual.view(-1, hr),
x.view(-1, hx),
weight.t(),
beta=beta,
out=residual.view(-1, hr))
else:
raise AssertionError(
"unrecognized tensor dimensions: {}".format(x.dim()))
if bias is not None:
residual += bias
return residual
else:
if output is not None:
if bias is not None: # always separate bias add when output is provided
torch.matmul(x, weight.t(), out=output)
output.add_(bias)
return output
return torch.matmul(x, weight.t(), out=output)
else:
return dispatch_unquantized_gemm()(x, layer.weight, bias)
# return dispatch_unquantized_gemm()(x, layer.weight, bias)
class UnquantizedMoELinearMethod(LinearMethodBase):
......@@ -633,7 +669,8 @@ class ColumnParallelLinear(LinearBase):
self, input_,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
update_hd: Optional[bool] = True,
output: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
......@@ -663,7 +700,7 @@ class ColumnParallelLinear(LinearBase):
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
output_parallel = self.quant_method.apply(self, input_, bias, output=output)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
......@@ -1703,6 +1740,9 @@ class RowParallelLinear(LinearBase):
def forward(
self, input_,
use_fused_silu_mul_quant: Optional[bool] = False,
residual=None,
output=None,
disable_allreduce=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
......@@ -1712,7 +1752,14 @@ class RowParallelLinear(LinearBase):
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# only add residual to the first rank
if residual is not None and self.tp_size > 1 and get_tensor_model_parallel_rank(
) != 0:
residual *= 0
# Matrix multiply.
if output is not None:
assert disable_allreduce or not self.reduce_results
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
......@@ -1728,19 +1775,28 @@ class RowParallelLinear(LinearBase):
else:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
residual=residual,
output=output)
if self.reduce_results and self.tp_size > 1 and not disable_allreduce:
if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel)
output_ = self.tbo_all_reduce(output_parallel)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_ = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
# output_bias = self.bias if self.skip_bias_add else None
# if not self.return_bias:
# return output
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment