Unverified Commit bdcee0ff authored by jiqing-feng, committed by GitHub

fix triton kernel on the correct device (#1691)


Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
parent 6d0a5cd2
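
For context, a minimal standalone sketch of the pattern this commit applies across the Triton ops below. Only the two module-level detection lines mirror the actual change; the helper run_with_device_context is an illustrative assumption and not part of the diff. The idea: pick the active accelerator backend, then launch Triton kernels inside a device context so they execute on the input tensor's device rather than the process-wide default device.

import torch

# Backend detection as introduced by the diff below: prefer torch.accelerator
# when available (recent PyTorch, covers XPU as well), otherwise fall back to CUDA.
device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)

def run_with_device_context(A, kernel_launch, *args, **kwargs):
    # Illustrative helper (not in the commit): Triton resolves the launch device
    # from the current device, so make A.device current before launching.
    with torch_accelerator_module.device(A.device):
        return kernel_launch(*args, **kwargs)
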
@@ -9,6 +9,8 @@ from . import triton_kernels
# from bitsandbytes.functional import get_4bit_type
# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
@@ -21,7 +23,9 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
    out = torch.empty_like(A.flatten(), dtype=torch.uint8)
    with torch_accelerator_module.device(A.device):
        triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
    out = out.reshape(A.shape)
    return out, absmax.float()
@@ -35,6 +39,7 @@ def dequantize_blockwise(
    # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
    out = torch.empty_like(A, dtype=dtype, device=A.device)
    with torch_accelerator_module.device(A.device):
        triton_kernels.dequant_int8_blockwise(
            A,
            code,
@@ -55,6 +60,7 @@ def dequantize_blockwise_inplace(
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    with torch_accelerator_module.device(A.device):
        triton_kernels.dequant_int8_blockwise(
            A,
            code,
@@ -84,6 +90,7 @@ def quantize_4bit(
    absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
    out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
    with torch_accelerator_module.device(A.device):
        triton_kernels.quantize_4bit_blockwise_triton(
            A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
        )
@@ -119,7 +126,9 @@ def dequantize_4bit(
    out = torch.empty(shape, dtype=dtype, device=A.device)
    with torch_accelerator_module.device(A.device):
        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out
@@ -134,6 +143,7 @@ def dequantize_4bit_inplace(
) -> None:
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    with torch_accelerator_module.device(A.device):
        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
@@ -150,6 +160,7 @@ def gemv_4bit(
    B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)
    with torch_accelerator_module.device(A.device):
        triton_kernels._dequantize_4bit_impl_passing_code(
            B,
            absmax,
......
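
A hedged usage sketch of the scenario this fix targets; the import path, blocksize, and the assumption of a second accelerator at index 1 are illustrative, not taken from the commit. With the device context in place, the blockwise quantization kernel launches on the tensor's device instead of the default device 0.

import torch
from bitsandbytes.functional import create_dynamic_map
from bitsandbytes.backends.triton import ops as triton_ops  # assumed module path

device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"

# Assumes at least two accelerators; device 1 is deliberately not the default device.
A = torch.randn(4096, dtype=torch.float32, device=f"{device_type}:1")
code = create_dynamic_map().to(A.device)  # standard 8-bit dynamic quantization code table

out, absmax = triton_ops.quantize_blockwise(A, code, blocksize=256)
assert out.device == A.device  # kernel ran on device 1, not the default device
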