Unverified Commit bdcee0ff authored by jiqing-feng, committed by GitHub

fix: run triton kernels on the correct device (#1691)


Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
parent 6d0a5cd2
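
The change applies one pattern throughout: Triton launches kernels on the process's *current* device, so when an input tensor lives on a non-default device (e.g. the second GPU), the kernel would previously launch on device 0 and mismatch the tensor's device. Each kernel call is therefore wrapped in a device context that makes the input's device current first. A minimal sketch of the guard, with a hypothetical kernel launch, assuming a CUDA build with at least two devices:

import torch

A = torch.randn(4096, device="cuda:1")
out = torch.empty_like(A, dtype=torch.uint8)

# Triton launches on the current device; without this guard the kernel
# would run on cuda:0 while A and out live on cuda:1.
with torch.cuda.device(A.device):
    # A Triton kernel launched here runs on cuda:1, e.g.
    # triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
    pass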
@@ -9,6 +9,8 @@ from . import triton_kernels
 # from bitsandbytes.functional import get_4bit_type
 # _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
 # _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
+device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+torch_accelerator_module = getattr(torch, device_type, torch.cuda)
 
 
 def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
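
The two added lines resolve the backend module once at import time: torch.accelerator.current_accelerator() (available since PyTorch 2.6) reports the active backend, so device_type becomes e.g. "cuda" or "xpu" and torch_accelerator_module becomes torch.cuda or torch.xpu; older builds without torch.accelerator fall back to torch.cuda. A quick illustration, assuming a PyTorch 2.6+ build with an accelerator present:

import torch

device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)

print(device_type)               # "cuda" on a CUDA build, "xpu" on an XPU build
print(torch_accelerator_module)  # <module 'torch.cuda'> or <module 'torch.xpu'>
# Both modules expose a device(...) context manager, so the same
# guard works unchanged across backends.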
@@ -21,7 +23,9 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
     absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
     out = torch.empty_like(A.flatten(), dtype=torch.uint8)
 
-    triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
     out = out.reshape(A.shape)
 
     return out, absmax.float()
@@ -35,13 +39,14 @@ def dequantize_blockwise(
     # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
 
     out = torch.empty_like(A, dtype=dtype, device=A.device)
 
-    triton_kernels.dequant_int8_blockwise(
-        A,
-        code,
-        absmax,
-        out,
-        blocksize,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.dequant_int8_blockwise(
+            A,
+            code,
+            absmax,
+            out,
+            blocksize,
+        )
 
     return out
@@ -55,13 +60,14 @@ def dequantize_blockwise_inplace(
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
     torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
 
-    triton_kernels.dequant_int8_blockwise(
-        A,
-        code,
-        absmax,
-        out,
-        blocksize,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.dequant_int8_blockwise(
+            A,
+            code,
+            absmax,
+            out,
+            blocksize,
+        )
 
 
 def quantize_4bit(
@@ -84,9 +90,10 @@ def quantize_4bit(
     absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
     out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
 
-    triton_kernels.quantize_4bit_blockwise_triton(
-        A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.quantize_4bit_blockwise_triton(
+            A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
+        )
     packed = out
 
     if quant_storage != torch.uint8:
@@ -119,7 +126,9 @@ def dequantize_4bit(
     out = torch.empty(shape, dtype=dtype, device=A.device)
 
-    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
 
     return out
@@ -134,7 +143,8 @@ def dequantize_4bit_inplace(
 ) -> None:
     torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
     torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
-    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
 
 
 def gemv_4bit(
@@ -150,14 +160,15 @@ def gemv_4bit(
     B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)
 
-    triton_kernels._dequantize_4bit_impl_passing_code(
-        B,
-        absmax,
-        blocksize,
-        code,
-        dtype=A.dtype,
-        out=B_dq_triton,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl_passing_code(
+            B,
+            absmax,
+            blocksize,
+            code,
+            dtype=A.dtype,
+            out=B_dq_triton,
+        )
 
     return torch.nn.functional.linear(
         A,
...
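
With the guards in place, these ops can be handed tensors living on any device, not just the default one. A usage sketch (hypothetical import path and a placeholder code table; assumes a CUDA build with two devices):

import torch
from bitsandbytes.backends.triton.ops import quantize_blockwise  # hypothetical path

code = torch.linspace(-1.0, 1.0, 256, device="cuda:1")  # placeholder 8-bit code table
A = torch.randn(8192, device="cuda:1")

# The Triton kernel now launches on cuda:1, matching A's device.
out, absmax = quantize_blockwise(A, code, blocksize=256)
assert out.device == A.device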