Commit e67bfccb, authored Apr 12, 2023 by Tim Dettmers

Added missing triton and fp8 files.

parent ec1ea637
Showing 10 changed files with 1104 additions and 0 deletions.
bitsandbytes/research/__init__.py                         +7    -0
bitsandbytes/research/autograd/__init__.py                +0    -0
bitsandbytes/research/autograd/_functions.py              +493  -0
bitsandbytes/triton/__init__.py                           +0    -0
bitsandbytes/triton/dequantize_rowwise.py                 +58   -0
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py       +158  -0
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py     +159  -0
bitsandbytes/triton/quantize_columnwise_and_transpose.py  +68   -0
bitsandbytes/triton/quantize_global.py                    +100  -0
bitsandbytes/triton/quantize_rowwise.py                   +61   -0
bitsandbytes/research/__init__.py  0 → 100644

from .autograd._functions import (
    matmul_fp8,
    switchback_bnb,
    matmul_fp8_global,
    matmul_fp8_mixed,
)
bitsandbytes/research/autograd/__init__.py  0 → 100644  (empty file)
bitsandbytes/research/autograd/_functions.py  0 → 100644

import operator
import warnings
from dataclasses import dataclass
from functools import reduce  # Required in Python 3
import torch

import bitsandbytes.functional as F

from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler


# math.prod not compatible with python < 3.8
def prod(iterable):
    return reduce(operator.mul, iterable, 1)

tensor = torch.Tensor


class MatMulFP8(torch.autograd.Function):
    # forward is the same, but we added the fallback for pre-turing GPUs
    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B

            B_shape = B.shape
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Dequantize
        # 2. MatmulnN
        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
        fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)

        cB, state = F.quantize(B.float(), code=fw_code)
        fp8B = F.dequantize(cB, state).to(B.dtype)

        output = torch.matmul(fp8A, fp8B)

        # output is half

        # 3. Save state
        ctx.fw_code = fw_code
        ctx.bw_code = bw_code
        ctx.bsz = bsz
        ctx.bsz2 = bsz2
        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

        if any(ctx.needs_input_grad[:2]):
            # NOTE: we send back A, and re-quant.
            ctx.tensors = (A, fp8B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None

        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
        A, B = ctx.tensors

        grad_A, grad_B = None, None

        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)

        cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
        fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)

        # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
        # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
        # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
        # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])

        # not supported by PyTorch. TODO: create work-around
        if req_gradA:
            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

        if req_gradB:
            if len(A.shape) == 3:
                At = A.transpose(2, 1).contiguous()
            else:
                At = A.transpose(1, 0).contiguous()
            cA, state = F.quantize(At.float(), code=ctx.fw_code)
            fp8At = F.dequantize(cA, state).to(A.dtype)
            grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)

        return grad_A, grad_B, None, None, None, None, None


class MatMulFP8Mixed(torch.autograd.Function):
    # forward is the same, but we added the fallback for pre-turing GPUs
    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B

            B_shape = B.shape
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Dequantize
        # 2. MatmulnN
        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
        fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)

        cB, state = F.quantize(B.float(), code=fw_code)
        fp8B = F.dequantize(cB, state).to(B.dtype)

        output = torch.matmul(fp8A, fp8B)

        # output is half

        # 3. Save state
        ctx.fw_code = fw_code
        ctx.bw_code = bw_code
        ctx.bsz = bsz
        ctx.bsz2 = bsz2
        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

        if any(ctx.needs_input_grad[:2]):
            # NOTE: we send back A, and re-quant.
            ctx.tensors = (A, fp8B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None

        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
        A, B = ctx.tensors

        grad_A, grad_B = None, None

        # TODO: Fix blocksize to be output_dim
        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)

        # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
        # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)

        # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
        # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
        # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
        # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])

        # not supported by PyTorch. TODO: create work-around
        if req_gradA:
            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

        if req_gradB:
            At = A.transpose(2, 1).contiguous()
            # cA, state = F.quantize(At.float(), code=ctx.fw_code)
            # fp8At = F.dequantize(cA, state).to(A.dtype)
            grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype)

        return grad_A, grad_B, None, None, None, None, None


class MatMulFP8Global(torch.autograd.Function):
    # forward is the same, but we added the fallback for pre-turing GPUs
    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B

            B_shape = B.shape
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Dequantize
        # 2. MatmulnN
        cA, state = F.quantize(A.float(), code=fw_code)
        fp8A = F.dequantize(cA, state).to(A.dtype)

        cB, state = F.quantize(B.float(), code=fw_code)
        fp8B = F.dequantize(cB, state).to(B.dtype)

        output = torch.matmul(fp8A, fp8B)

        # output is half

        # 3. Save state
        ctx.fw_code = fw_code
        ctx.bw_code = bw_code
        ctx.bsz = bsz
        ctx.bsz2 = bsz2
        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

        if any(ctx.needs_input_grad[:2]):
            # NOTE: we send back A, and re-quant.
            ctx.tensors = (A, fp8B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None

        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
        A, B = ctx.tensors

        grad_A, grad_B = None, None

        # TODO: Fix blocksize to be output_dim
        cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code)
        fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype)

        # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
        # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)

        # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
        # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
        # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
        # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])

        # not supported by PyTorch. TODO: create work-around
        if req_gradA:
            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

        if req_gradB:
            At = A.transpose(2, 1).contiguous()
            cA, state = F.quantize(At.float(), code=ctx.fw_code)
            fp8At = F.dequantize(cA, state).to(A.dtype)
            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)

        return grad_A, grad_B, None, None, None, None, None


class MatMul8bitMixed(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)

        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()

        # Cast A to fp16
        if A.dtype != torch.float16:
            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
                idx = torch.unique(coo_tensorA.colidx).long()
                CA[:, idx] = 0
                CAt[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None:
                    # B is in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
        else:
            #print('A shape', A.shape)
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
        if state.has_fp16_weights:
            #print('B shape', B.shape)
            has_grad = True if (getattr(B, "grad", None) is not None) else False
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                (
                    CB,
                    state.CBt,
                    state.SCB,
                    state.SCBt,
                    coo_tensorB,
                ) = F.double_quant(B.to(torch.float16))
                state.CxB, state.SB = F.transform(CB, to_order=formatB)
        else:
            has_grad = False

        if coo_tensorA is not None and not state.has_fp16_weights:
            # extract outliers

            outlier_idx = torch.unique(coo_tensorA.colidx)
            state.idx = outlier_idx
            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
            #     # do not use pool for 2nd FFN layer
            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #     state.idx = outlier_idx
            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            state.subB = (
                (outliers * state.SCB.view(-1, 1) / 127.0)
                .t()
                .contiguous()
                .to(A.dtype)
            )
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

        shapeB = state.SB[0]

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        C32A, SA = F.transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        # we apply the fused bias here

        if bias is None or bias.dtype == torch.float16:
            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
            output = output.to(A.dtype)
        else:
            # apply bias separately
            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
            output = output.to(A.dtype).add_(bias)

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
            output += torch.matmul(subA, state.subB)

        # 5. Save state
        ctx.state = state

        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (CAt, subA, A)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
        CAt, subA, A = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        grad_A = grad_B = grad_bias = None

        if req_gradBias:
            # compute grad_bias first before changing grad_output dtype
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        # Cast grad_output to fp16
        if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(
                -1, grad_output.shape[-1]
            ).contiguous()

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))

        if req_gradB:
            # print('back A shape', A.shape)
            # print('grad output t shape', grad_output.t().shape)
            grad_B = torch.matmul(grad_output.t(), A)

        if req_gradA:
            if state.CBt is not None:
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
                    state.CxBt, state.SBt = F.transform(
                        state.CBt, to_order=formatB, transpose=True
                    )
                # print('back B shape', state.CxBt.shape)
                # print('back grad shape', C32grad.shape)
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

            elif state.CB is not None:
                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            else:
                raise Exception('State must contain either CBt or CB matrix for backward')

        return grad_A, grad_B, None, grad_bias, None


def get_block_sizes(input_matrix, weight_matrix):
    input_features = input_matrix.shape[-1]
    output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1])
    array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
    bsz, bsz2 = 1024, 1024
    for i, k in enumerate(array):
        if input_features > array[i + 1]:
            bsz = k
            break
    for i, k in enumerate(array):
        if output_features > array[i + 1]:
            bsz2 = k
            break

    return bsz, bsz2


def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz: int = -1, bsz2: int = -1):
    if bsz == -1 or bsz2 == -1:
        bsz, bsz2 = get_block_sizes(A, B)
    return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz, bsz2)


def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz: int = -1, bsz2: int = -1):
    if bsz == -1 or bsz2 == -1:
        bsz, bsz2 = get_block_sizes(A, B)
    return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)


def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz: int = -1, bsz2: int = -1):
    if bsz == -1 or bsz2 == -1:
        bsz, bsz2 = get_block_sizes(A, B)
    return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)


def switchback_bnb(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
    bias=None
):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitMixed.apply(A, B, out, bias, state)
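The functions above are re-exported through bitsandbytes.research (see the __init__.py earlier in this commit). A minimal usage sketch follows; it is not part of the committed files, assumes a CUDA device, and uses F.create_dynamic_map() only as a stand-in for whatever 8-bit code map the caller wants to pass as fw_code/bw_code.

# Hypothetical usage sketch (not from this commit): drive MatMulFP8 through matmul_fp8.
import torch
import bitsandbytes.functional as F
from bitsandbytes.research import matmul_fp8

A = torch.randn(4, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
B = torch.randn(1024, 2048, device="cuda", dtype=torch.float16, requires_grad=True)

# Stand-in code maps; any 256-entry quantization map tensor could be used here.
fw_code = F.create_dynamic_map().to(A.device)
bw_code = F.create_dynamic_map().to(A.device)

out = matmul_fp8(A, B, fw_code, bw_code)  # block sizes chosen by get_block_sizes(A, B)
out.float().sum().backward()              # grads computed by MatMulFP8.backward
print(A.grad.shape, B.grad.shape)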
bitsandbytes/triton/__init__.py  0 → 100644  (empty file)
bitsandbytes/triton/dequantize_rowwise.py  0 → 100644

import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# rowwise quantize

# TODO: autotune this better.
@triton.autotune(
        configs=[
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
        key=['n_elements']
)
@triton.jit
def _dequantize_rowwise(
    x_ptr,
    state_x,
    output_ptr,
    inv_127,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    arange = tl.arange(0, P2)
    offsets = block_start + arange
    row_mask = arange < BLOCK_SIZE
    x = tl.load(x_ptr + offsets, mask=row_mask)
    max_val = tl.load(state_x + pid)
    output = max_val * x * inv_127
    tl.store(output_ptr + offsets, output, mask=row_mask)


def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
    output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)

    P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))

    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (x.shape[0],)
    _dequantize_rowwise[grid](x, state_x, output, 1. / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
    return output
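A small round-trip sketch (not part of the committed file; assumes a CUDA device with Triton installed) pairing this kernel with quantize_rowwise from the sibling file added in this commit:

# Hypothetical round-trip check: quantize row-wise, then dequantize and compare.
import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise

x = torch.randn(64, 128, device="cuda", dtype=torch.float16)
x_int8, state_x = quantize_rowwise(x)        # int8 values + per-row absmax (fp16)
x_rec = dequantize_rowwise(x_int8, state_x)  # absmax * int8 / 127, back in fp16
print((x - x_rec).abs().max())               # error on the order of absmax / 127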
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py  0 → 100644

import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# This is a matmul kernel based on triton.ops.matmul
# It is modified to support rowwise quantized input and global quantized weight
# Its purpose is fused matmul then dequantize
# It does support bias.

def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()


def get_configs_io_bound():
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
                                      num_stages=num_stages, num_warps=num_warps))
                    # split_k
                    for split_k in [2, 4, 8, 16]:
                        configs.append(
                            triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
                                          num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
    return configs


@triton.autotune(
    configs=[
        # basic configs for compute-bound matmuls
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias: tl.constexpr,
                                  stride_am, stride_ak,
                                  stride_bk, stride_bn,
                                  stride_cm, stride_cn,
                                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                                  GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
                                  ACC_TYPE: tl.constexpr
                                  ):
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    w_factor = tl.load(state_w_ptr)
    x_factor = tl.load(state_x_ptr + ram)[:, None]

    # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk

    acc = (w_factor * (x_factor * (acc * divfactor)))
    acc = acc.to(C.dtype.element_ty)

    # conditionally add bias
    if has_bias:
        bias = tl.load(bias + rn).to(C.dtype.element_ty)
        acc = acc + bias[None, :]

    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)


def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
    device = a.device
    divfactor = 1. / (127. * 127.)
    has_bias = 0 if bias is None else 1
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c = torch.empty((M, N), device=device, dtype=torch.float16)
    # accumulator types
    ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
    # launch int8_matmul_mixed_dequantize kernel
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
    _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
                                        a.stride(0), a.stride(1),
                                        b.stride(0), b.stride(1),
                                        c.stride(0), c.stride(1),
                                        GROUP_M=8, ACC_TYPE=ACC_TYPE)
    return c
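A hedged driver sketch for the wrapper above (my reading of the argument layout, not code from this commit): the activation is quantized row-wise, the weight with a single global scale, and the transposed int8 weight is passed as b so that b has shape (K, N).

# Hypothetical driver; assumes CUDA + Triton and the sibling quantize kernels from this commit.
import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_global import quantize_global
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze

x = torch.randn(32, 1024, device="cuda", dtype=torch.float16)    # (M, K) activations
w = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)  # (N, K) weight

x_int8, state_x = quantize_rowwise(x)   # per-row absmax for the activations
w_int8, state_w = quantize_global(w)    # one absmax for the whole weight

out = int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x, state_w, bias=None)
print(out.shape, out.dtype)  # (32, 4096), torch.float16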
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py  0 → 100644

import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# This is a matmul kernel based on triton.ops.matmul
# It is modified to support rowwise quantized input and columnwise quantized weight
# Its purpose is fused matmul then dequantize
# It does support bias.

def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()


def get_configs_io_bound():
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
                                      num_stages=num_stages, num_warps=num_warps))
                    # split_k
                    for split_k in [2, 4, 8, 16]:
                        configs.append(
                            triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
                                          num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
    return configs


@triton.autotune(
    configs=[
        # basic configs for compute-bound matmuls
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias: tl.constexpr,
                                    stride_am, stride_ak,
                                    stride_bk, stride_bn,
                                    stride_cm, stride_cn,
                                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                                    GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
                                    ACC_TYPE: tl.constexpr
                                    ):
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    w_factor = tl.load(state_w_ptr + rbn)[None, :]
    x_factor = tl.load(state_x_ptr + ram)[:, None]

    # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk

    acc = (w_factor * (x_factor * (acc * divfactor)))
    acc = acc.to(C.dtype.element_ty)

    if has_bias:
        bias = tl.load(bias + rn).to(C.dtype.element_ty)
        acc = acc + bias[None, :]

    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)


def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
    divfactor = 1. / (127. * 127.)

    has_bias = 0 if bias is None else 1

    device = a.device
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c = torch.empty((M, N), device=device, dtype=torch.float16)
    # accumulator types
    ACC_TYPE = tl.float32  # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
    # launch int8_matmul_rowwise_dequantize kernel
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
    _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
                                          a.stride(0), a.stride(1),
                                          b.stride(0), b.stride(1),
                                          c.stride(0), c.stride(1),
                                          GROUP_M=8, ACC_TYPE=ACC_TYPE)
    return c
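The row-wise variant is driven the same way except that the weight also carries per-row scales, which become per-output-column scales once the int8 weight is transposed. A sketch under the same assumptions as above (not part of this commit):

# Hypothetical driver; assumes CUDA + Triton.
import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize

x = torch.randn(32, 1024, device="cuda", dtype=torch.float16)    # (M, K)
w = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)  # (N, K)

x_int8, state_x = quantize_rowwise(x)  # per-row absmax of the activations
w_int8, state_w = quantize_rowwise(w)  # per-row absmax of the weight rows

out = int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x, state_w, bias=None)
print(out.shape, out.dtype)  # (32, 4096), torch.float16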
bitsandbytes/triton/quantize_columnwise_and_transpose.py  0 → 100644

import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# This kernel does fused columnwise quantization and transpose.

# TODO: autotune this better.
@triton.autotune(
        configs=[
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_stages=16),
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=16, num_warps=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
        key=['n_elements']
)
@triton.jit
def _quantize_columnwise_and_transpose(
    x_ptr,
    output_ptr,
    output_maxs,
    n_elements,
    M: tl.constexpr, N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid
    p2_arange = tl.arange(0, P2)
    p2_arange_mask = p2_arange < M
    arange = p2_arange * N
    offsets = block_start + arange
    x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
    abs_x = tl.abs(x)
    max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
    output = tl.libdevice.llrint(127. * (x / max_val))

    new_start = pid * M
    new_offsets = new_start + p2_arange
    tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
    tl.store(output_maxs + pid, max_val)


def quantize_columnwise_and_transpose(x: torch.Tensor):
    M, N = x.shape
    output = torch.empty(N, M, device=x.device, dtype=torch.int8)
    output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)

    P2 = int(2 ** (math.ceil(math.log2(M))))

    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
    return output, output_maxs
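For orientation, a sketch of what the fused kernel computes, checked against plain PyTorch ops (an illustration only, not part of the committed file; assumes CUDA + Triton; rounding may differ by one step at ties):

# Hypothetical sanity check for quantize_columnwise_and_transpose.
import torch
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose

x = torch.randn(256, 512, device="cuda", dtype=torch.float16)
xt_int8, col_maxs = quantize_columnwise_and_transpose(x)

print(xt_int8.shape)   # (512, 256): quantized along columns of x, stored transposed
print(col_maxs.shape)  # (512,): one absmax per original column

# Reference computation with plain PyTorch ops.
ref_maxs = x.abs().amax(dim=0)
ref = torch.round(127.0 * (x / ref_maxs)).t().to(torch.int8)
print((ref.float() - xt_int8.float()).abs().max())  # expected to be 0 or 1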
bitsandbytes/triton/quantize_global.py  0 → 100644

import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# global quantize
@triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
            triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
        ],
        key=['n_elements']
)
@triton.jit
def _quantize_global(
    x_ptr,
    absmax_inv_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    absmax_inv = tl.load(absmax_inv_ptr)
    output = tl.libdevice.llrint(127. * (x * absmax_inv))
    tl.store(output_ptr + offsets, output, mask=mask)


def quantize_global(x: torch.Tensor):
    absmax = x.abs().max().unsqueeze(0)
    absmax_inv = 1. / absmax
    output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _quantize_global[grid](x, absmax_inv, output, n_elements)
    return output, absmax


# global quantize and transpose
@triton.autotune(
        configs=[
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
            # ...
        ],
        key=['M', 'N']
)
@triton.jit
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
                               BLOCK_M: tl.constexpr,
                               BLOCK_N: tl.constexpr,
                               GROUP_M: tl.constexpr):
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    a = tl.load(A, mask=mask)
    absmax_inv = tl.load(absmax_inv_ptr)

    # rematerialize to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]

    output = tl.libdevice.llrint(127. * (a * absmax_inv))

    tl.store(B, output, mask=mask)


def quantize_global_transpose(input):
    absmax = input.abs().max().unsqueeze(0)
    absmax_inv = 1. / absmax
    M, N = input.shape
    out = torch.empty(N, M, device='cuda', dtype=torch.int8)

    assert out.size(0) == N and out.size(1) == M
    assert input.stride(0) == 1 or input.stride(1) == 1
    assert out.stride(0) == 1 or out.stride(1) == 1

    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
    _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
    return out, absmax
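A short usage sketch for the two global-scale helpers above (not part of the committed file; assumes a CUDA device with Triton installed):

# Hypothetical usage of quantize_global / quantize_global_transpose.
import torch
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose

w = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)

w_int8, absmax = quantize_global(w)               # single absmax for the whole tensor
wt_int8, absmax_t = quantize_global_transpose(w)  # same scale, values stored transposed

w_rec = (w_int8.float() * absmax / 127.0).to(w.dtype)  # approximate reconstruction
print((w - w_rec).abs().max(), wt_int8.shape)          # small error; torch.Size([1024, 4096])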
bitsandbytes/triton/quantize_rowwise.py  0 → 100644

import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# rowwise quantize

# TODO: autotune this better.
@triton.autotune(
        configs=[
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
        key=['n_elements']
)
@triton.jit
def _quantize_rowwise(
    x_ptr,
    output_ptr,
    output_maxs,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    arange = tl.arange(0, P2)
    offsets = block_start + arange
    row_mask = arange < BLOCK_SIZE
    x = tl.load(x_ptr + offsets, mask=row_mask)

    abs_x = tl.abs(x)
    max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
    output = tl.libdevice.llrint(127. * (x / max_val))
    tl.store(output_ptr + offsets, output, mask=row_mask)
    tl.store(output_maxs + pid, max_val)


def quantize_rowwise(x: torch.Tensor):
    output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
    output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)

    P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))

    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (x.shape[0],)
    _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
    return output, output_maxs