Unverified Commit a09d05a0 authored by jiqing-feng's avatar jiqing-feng Committed by GitHub
Browse files

add int mm for xpu after torch 2.9 (#1736)



* add int mm for xpu after torch 2.9
Signed-off-by: default avatarjiqing-feng <jiqing.feng@intel.com>

* add packaging on pyproject
Signed-off-by: default avatarjiqing-feng <jiqing.feng@intel.com>

---------
Signed-off-by: default avatarjiqing-feng <jiqing.feng@intel.com>
parent c76e208f
from collections.abc import Sequence from collections.abc import Sequence
import warnings import warnings
from packaging import version
import torch import torch
from ..._ops import register_kernel from ..._ops import register_kernel
from ..utils import ipex_xpu, triton_available from ..utils import ipex_xpu, triton_available
# _int_mm is available in torch starting from 2.7 version, # _int_mm is available in torch starting from 2.9 version, or ipex 2.7
# but currently it's don't have xpu implementation. if version.parse(torch.__version__).release >= version.parse("2.9").release or (
if ipex_xpu and torch.__version__ >= (2, 7): ipex_xpu and torch.__version__ >= (2, 7)
):
@register_kernel("bitsandbytes::int8_linear_matmul", "xpu") @register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
def _(A: torch.Tensor, B: torch.Tensor): def _(A: torch.Tensor, B: torch.Tensor):
......
...@@ -43,7 +43,8 @@ classifiers = [ ...@@ -43,7 +43,8 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"torch>=2.2,<3", "torch>=2.2,<3",
"numpy>=1.17" "numpy>=1.17",
"packaging>=20.9"
] ]
[project.urls] [project.urls]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment