"vscode:/vscode.git/clone" did not exist on "b1ddea7fd94c52a7be76cec721d32d438681af83"
Commit 56ebbba3 authored by zhuwenwen
Browse files

update linear of RowParallelLinear and UnquantizedLinearMethod apply

parent ee346e93
...@@ -250,7 +250,9 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -250,7 +250,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
def apply(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
output: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn: if self.use_llama_nn:
if gemm_bank_conf(layer.weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1': if gemm_bank_conf(layer.weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1':
layer.weight = layer.weight[:,:-32] layer.weight = layer.weight[:,:-32]
...@@ -265,8 +267,42 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -265,8 +267,42 @@ class UnquantizedLinearMethod(LinearMethodBase):
else: else:
if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]: if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(x, layer.weight.t(), bias) return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)
else:
weight = layer.weight
if residual is not None:
assert output is None or output is residual
if get_tensor_model_parallel_world_size(
) > 1 and get_tensor_model_parallel_rank() != 0:
beta = 0.0
else:
beta = 1.0
# optimize cuda memory usage
if x.dim() == 2:
torch.addmm(residual, x, weight.t(), beta=beta, out=residual)
elif x.dim() >= 3:
hx = x.size(-1)
hr = residual.size(-1)
torch.addmm(residual.view(-1, hr),
x.view(-1, hx),
weight.t(),
beta=beta,
out=residual.view(-1, hr))
else:
raise AssertionError(
"unrecognized tensor dimensions: {}".format(x.dim()))
if bias is not None:
residual += bias
return residual
else:
if output is not None:
if bias is not None: # always separate bias add when output is provided
torch.matmul(x, weight.t(), out=output)
output.add_(bias)
return output
return torch.matmul(x, weight.t(), out=output)
else: else:
return dispatch_unquantized_gemm()(x, layer.weight, bias) return dispatch_unquantized_gemm()(x, layer.weight, bias)
# return dispatch_unquantized_gemm()(x, layer.weight, bias)
class UnquantizedMoELinearMethod(LinearMethodBase): class UnquantizedMoELinearMethod(LinearMethodBase):
...@@ -633,7 +669,8 @@ class ColumnParallelLinear(LinearBase): ...@@ -633,7 +669,8 @@ class ColumnParallelLinear(LinearBase):
self, input_, self, input_,
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True update_hd: Optional[bool] = True,
output: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None input_quant_args = None
...@@ -663,7 +700,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -663,7 +700,7 @@ class ColumnParallelLinear(LinearBase):
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
# Matrix multiply. # Matrix multiply.
assert self.quant_method is not None assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias) output_parallel = self.quant_method.apply(self, input_, bias, output=output)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel) output = tensor_model_parallel_all_gather(output_parallel)
...@@ -1703,6 +1740,9 @@ class RowParallelLinear(LinearBase): ...@@ -1703,6 +1740,9 @@ class RowParallelLinear(LinearBase):
def forward( def forward(
self, input_, self, input_,
use_fused_silu_mul_quant: Optional[bool] = False, use_fused_silu_mul_quant: Optional[bool] = False,
residual=None,
output=None,
disable_allreduce=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
...@@ -1712,7 +1752,14 @@ class RowParallelLinear(LinearBase): ...@@ -1712,7 +1752,14 @@ class RowParallelLinear(LinearBase):
input_, num_partitions=self.tp_size) input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous() input_parallel = splitted_input[tp_rank].contiguous()
# only add residual to the first rank
if residual is not None and self.tp_size > 1 and get_tensor_model_parallel_rank(
) != 0:
residual *= 0
# Matrix multiply. # Matrix multiply.
if output is not None:
assert disable_allreduce or not self.reduce_results
assert self.quant_method is not None assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that # Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case) # bias will not get added more than once in TP>1 case)
...@@ -1728,19 +1775,28 @@ class RowParallelLinear(LinearBase): ...@@ -1728,19 +1775,28 @@ class RowParallelLinear(LinearBase):
else: else:
output_parallel = self.quant_method.apply(self, output_parallel = self.quant_method.apply(self,
input_parallel, input_parallel,
bias=bias_) residual=residual,
if self.reduce_results and self.tp_size > 1: output=output)
if self.reduce_results and self.tp_size > 1 and not disable_allreduce:
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel) output_ = self.tbo_all_reduce(output_parallel)
else: else:
output = tensor_model_parallel_all_reduce(output_parallel) output_ = tensor_model_parallel_all_reduce(output_parallel)
else: else:
output = output_parallel output_ = output_parallel
output_bias = self.bias if self.skip_bias_add else None if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
# output_bias = self.bias if self.skip_bias_add else None
# if not self.return_bias:
# return output
if not self.return_bias:
return output
return output, output_bias return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment