Commit 6737671c (unverified), authored by Even Zhou, committed by GitHub
Browse files

[Bugfix] Fix w8a8_int8 import error on NPU (#8147)

parent fd63b62e
@@ -754,6 +754,8 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.linear import RowParallelLinear
if isinstance(layer, RowParallelLinear): if isinstance(layer, RowParallelLinear):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
return self.quant_method.apply(layer, x, bias, tp_rank) return self.quant_method.apply(layer, x, bias, tp_rank)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment