Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
bdcee0ff
Unverified
Commit
bdcee0ff
authored
Jul 01, 2025
by
jiqing-feng
Committed by
GitHub
Jul 01, 2025
Browse files
fix triton kernel on the correct device (#1691)
Signed-off-by:
jiqing-feng
<
jiqing.feng@intel.com
>
parent
6d0a5cd2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
28 deletions
+39
-28
bitsandbytes/backends/triton/ops.py
bitsandbytes/backends/triton/ops.py
+39
-28
No files found.
bitsandbytes/backends/triton/ops.py
View file @
bdcee0ff
...
@@ -9,6 +9,8 @@ from . import triton_kernels
...
@@ -9,6 +9,8 @@ from . import triton_kernels
# from bitsandbytes.functional import get_4bit_type
# from bitsandbytes.functional import get_4bit_type
# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
device_type
=
torch
.
accelerator
.
current_accelerator
().
type
if
hasattr
(
torch
,
"accelerator"
)
else
"cuda"
torch_accelerator_module
=
getattr
(
torch
,
device_type
,
torch
.
cuda
)
def
quantize_blockwise
(
A
:
torch
.
Tensor
,
code
:
torch
.
Tensor
,
blocksize
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
quantize_blockwise
(
A
:
torch
.
Tensor
,
code
:
torch
.
Tensor
,
blocksize
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -21,7 +23,9 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
...
@@ -21,7 +23,9 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
absmax
=
torch
.
empty
((
blocks
,),
device
=
A
.
device
,
dtype
=
A
.
dtype
)
absmax
=
torch
.
empty
((
blocks
,),
device
=
A
.
device
,
dtype
=
A
.
dtype
)
out
=
torch
.
empty_like
(
A
.
flatten
(),
dtype
=
torch
.
uint8
)
out
=
torch
.
empty_like
(
A
.
flatten
(),
dtype
=
torch
.
uint8
)
with
torch_accelerator_module
.
device
(
A
.
device
):
triton_kernels
.
quantize_blockwise_triton
(
A
,
blocksize
,
code
,
blocks
,
absmax
,
out
)
triton_kernels
.
quantize_blockwise_triton
(
A
,
blocksize
,
code
,
blocks
,
absmax
,
out
)
out
=
out
.
reshape
(
A
.
shape
)
out
=
out
.
reshape
(
A
.
shape
)
return
out
,
absmax
.
float
()
return
out
,
absmax
.
float
()
...
@@ -35,6 +39,7 @@ def dequantize_blockwise(
...
@@ -35,6 +39,7 @@ def dequantize_blockwise(
# torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
# torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
out
=
torch
.
empty_like
(
A
,
dtype
=
dtype
,
device
=
A
.
device
)
out
=
torch
.
empty_like
(
A
,
dtype
=
dtype
,
device
=
A
.
device
)
with
torch_accelerator_module
.
device
(
A
.
device
):
triton_kernels
.
dequant_int8_blockwise
(
triton_kernels
.
dequant_int8_blockwise
(
A
,
A
,
code
,
code
,
...
@@ -55,6 +60,7 @@ def dequantize_blockwise_inplace(
...
@@ -55,6 +60,7 @@ def dequantize_blockwise_inplace(
torch
.
_check
(
out
.
device
==
A
.
device
,
lambda
:
f
"Expected out.device ==
{
A
.
device
}
, got
{
out
.
device
}
"
)
torch
.
_check
(
out
.
device
==
A
.
device
,
lambda
:
f
"Expected out.device ==
{
A
.
device
}
, got
{
out
.
device
}
"
)
torch
.
_check
(
out
.
dtype
==
dtype
,
lambda
:
f
"Expected out.dtype ==
{
dtype
}
, got
{
out
.
dtype
}
"
)
torch
.
_check
(
out
.
dtype
==
dtype
,
lambda
:
f
"Expected out.dtype ==
{
dtype
}
, got
{
out
.
dtype
}
"
)
with
torch_accelerator_module
.
device
(
A
.
device
):
triton_kernels
.
dequant_int8_blockwise
(
triton_kernels
.
dequant_int8_blockwise
(
A
,
A
,
code
,
code
,
...
@@ -84,6 +90,7 @@ def quantize_4bit(
...
@@ -84,6 +90,7 @@ def quantize_4bit(
absmax
=
torch
.
empty
((
blocks
*
2
,),
device
=
A
.
device
,
dtype
=
A
.
dtype
)
absmax
=
torch
.
empty
((
blocks
*
2
,),
device
=
A
.
device
,
dtype
=
A
.
dtype
)
out
=
torch
.
empty
((
n
//
2
,
1
),
device
=
A
.
device
,
dtype
=
torch
.
uint8
)
out
=
torch
.
empty
((
n
//
2
,
1
),
device
=
A
.
device
,
dtype
=
torch
.
uint8
)
with
torch_accelerator_module
.
device
(
A
.
device
):
triton_kernels
.
quantize_4bit_blockwise_triton
(
triton_kernels
.
quantize_4bit_blockwise_triton
(
A
,
blocksize
,
quant_type
,
blocks
,
absmax
,
num_elements
=
n
,
quantized_out
=
out
A
,
blocksize
,
quant_type
,
blocks
,
absmax
,
num_elements
=
n
,
quantized_out
=
out
)
)
...
@@ -119,7 +126,9 @@ def dequantize_4bit(
...
@@ -119,7 +126,9 @@ def dequantize_4bit(
out
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
A
.
device
)
out
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
A
.
device
)
with
torch_accelerator_module
.
device
(
A
.
device
):
triton_kernels
.
_dequantize_4bit_impl
(
A
,
absmax
,
blocksize
,
quant_type
,
dtype
,
out
=
out
)
triton_kernels
.
_dequantize_4bit_impl
(
A
,
absmax
,
blocksize
,
quant_type
,
dtype
,
out
=
out
)
return
out
return
out
...
@@ -134,6 +143,7 @@ def dequantize_4bit_inplace(
...
@@ -134,6 +143,7 @@ def dequantize_4bit_inplace(
)
->
None
:
)
->
None
:
torch
.
_check
(
out
.
shape
==
shape
,
lambda
:
f
"Expected out.shape ==
{
shape
}
, got
{
out
.
shape
}
"
)
torch
.
_check
(
out
.
shape
==
shape
,
lambda
:
f
"Expected out.shape ==
{
shape
}
, got
{
out
.
shape
}
"
)
torch
.
_check
(
out
.
dtype
==
dtype
,
lambda
:
f
"Expected out.dtype ==
{
dtype
}
, got
{
out
.
dtype
}
"
)
torch
.
_check
(
out
.
dtype
==
dtype
,
lambda
:
f
"Expected out.dtype ==
{
dtype
}
, got
{
out
.
dtype
}
"
)
with
torch_accelerator_module
.
device
(
A
.
device
):
triton_kernels
.
_dequantize_4bit_impl
(
A
,
absmax
,
blocksize
,
quant_type
,
dtype
,
out
=
out
)
triton_kernels
.
_dequantize_4bit_impl
(
A
,
absmax
,
blocksize
,
quant_type
,
dtype
,
out
=
out
)
...
@@ -150,6 +160,7 @@ def gemv_4bit(
...
@@ -150,6 +160,7 @@ def gemv_4bit(
B_dq_triton
=
torch
.
empty
(
shapeB
,
dtype
=
A
.
dtype
,
device
=
A
.
device
)
B_dq_triton
=
torch
.
empty
(
shapeB
,
dtype
=
A
.
dtype
,
device
=
A
.
device
)
with
torch_accelerator_module
.
device
(
A
.
device
):
triton_kernels
.
_dequantize_4bit_impl_passing_code
(
triton_kernels
.
_dequantize_4bit_impl_passing_code
(
B
,
B
,
absmax
,
absmax
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment