OpenDAS / bitsandbytes · Commit 51f8bb71

pre-triton update

Authored Mar 24, 2023 by Mitchell Wortsman
Parent: 75377d12
Showing 4 changed files with 360 additions and 13 deletions (+360 -13).
bitsandbytes/__init__.py             +4   -1
bitsandbytes/autograd/_functions.py  +266 -8
bitsandbytes/nn/__init__.py          +1   -1
bitsandbytes/nn/modules.py           +89  -3
bitsandbytes/__init__.py

@@ -11,7 +11,10 @@ from .autograd._functions import (
     matmul_cublas,
     mm_cublas,
     matmul_fp8,
-    matmul_mixed
+    matmul_mixed,
+    matmul_fp8_global,
+    matmul_fp4,
+    matmul_fp8_mixed,
 )
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
bitsandbytes/autograd/_functions.py

@@ -395,7 +395,7 @@ class MatMulFP8(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

     @staticmethod
-    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024):
+    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
@@ -425,6 +425,7 @@ class MatMulFP8(torch.autograd.Function):
         ctx.fw_code = fw_code
         ctx.bw_code = bw_code
         ctx.bsz = bsz
+        ctx.bsz2 = bsz2
         ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

         if any(ctx.needs_input_grad[:2]):
@@ -440,14 +441,13 @@ class MatMulFP8(torch.autograd.Function):
         if ctx.is_empty:
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None

-        req_gradA, req_gradB, _, _, _, _ = ctx.needs_input_grad
+        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
         A, B = ctx.tensors

         grad_A, grad_B = None, None
         # TODO: Fix blocksize to be output_dim
-        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz)
-        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz).to(grad_output.dtype)
+        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
+        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)

         cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
         fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
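With the new bsz2 argument, MatMulFP8.backward now quantizes grad_output blockwise with a block size derived from the output dimension rather than the input dimension. The snippet below is a minimal pure-PyTorch sketch of the blockwise quantize-and-dequantize round trip that F.quantize_blockwise / F.dequantize_blockwise perform on the gradient; the absmax scaling and integer code are illustrative stand-ins, not the bitsandbytes kernels or FP8 code tables.

import torch

def fake_quant_blockwise(x: torch.Tensor, blocksize: int = 1024, levels: int = 256) -> torch.Tensor:
    # Illustrative absmax blockwise round trip: split into blocks of `blocksize`
    # values, scale each block by its own absmax, round to an integer code,
    # then map back. This mimics the effect, not the implementation.
    flat = x.flatten()
    pad = (-flat.numel()) % blocksize
    flat = torch.nn.functional.pad(flat, (0, pad))
    blocks = flat.view(-1, blocksize)

    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)   # per-block scale
    q = torch.round(blocks / absmax * (levels // 2 - 1))               # integer code
    deq = q * absmax / (levels // 2 - 1)                               # dequantized blocks

    return deq.flatten()[: x.numel()].view_as(x)

grad_output = torch.randn(4, 16, 1024)
approx = fake_quant_blockwise(grad_output, blocksize=1024)
print((approx - grad_output).abs().max())  # round-trip error stays small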
@@ -467,7 +467,249 @@ class MatMulFP8(torch.autograd.Function):
             fp8At = F.dequantize(cA, state).to(A.dtype)
             grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)

-        return grad_A, grad_B, None, None, None, None
+        return grad_A, grad_B, None, None, None, None, None
+
+
+class MatMulFP8Mixed(torch.autograd.Function):
+    # forward is the same, but we added the fallback for pre-turing GPUs
+    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
+
+    @staticmethod
+    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
+        # default of pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+
+            B_shape = B.shape
+            if A.shape[-1] == B_shape[0]:
+                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
+
+        # 1. Dequantize
+        # 2. MatmulnN
+        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
+        fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)
+
+        cB, state = F.quantize(B.float(), code=fw_code)
+        fp8B = F.dequantize(cB, state).to(B.dtype)
+
+        output = torch.matmul(fp8A, fp8B)
+
+        # output is half
+
+        # 3. Save state
+        ctx.fw_code = fw_code
+        ctx.bw_code = bw_code
+        ctx.bsz = bsz
+        ctx.bsz2 = bsz2
+        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
+
+        if any(ctx.needs_input_grad[:2]):
+            # NOTE: we send back A, and re-quant.
+            ctx.tensors = (A, fp8B)
+        else:
+            ctx.tensors = (None, None)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.is_empty:
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None
+
+        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
+        A, B = ctx.tensors
+
+        grad_A, grad_B = None, None
+        # TODO: Fix blocksize to be output_dim
+        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
+        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
+
+        # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
+        # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
+
+        # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+        # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
+        # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
+        # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
+
+        # not supported by PyTorch. TODO: create work-around
+        if req_gradA:
+            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
+
+        if req_gradB:
+            At = A.transpose(2, 1).contiguous()
+            # cA, state = F.quantize(At.float(), code=ctx.fw_code)
+            # fp8At = F.dequantize(cA, state).to(A.dtype)
+            grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype)
+
+        return grad_A, grad_B, None, None, None, None, None
+
+
+class MatMulFP4(torch.autograd.Function):
+    # forward is the same, but we added the fallback for pre-turing GPUs
+    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
+
+    @staticmethod
+    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
+        # default of pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+
+            B_shape = B.shape
+            if A.shape[-1] == B_shape[0]:
+                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
+
+        # 1. Dequantize
+        # 2. MatmulnN
+        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
+        fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)
+
+        cB, state = F.quantize(B.float(), code=fw_code)
+        fp8B = F.dequantize(cB, state).to(B.dtype)
+
+        output = torch.matmul(fp8A, fp8B)
+
+        # output is half
+
+        # 3. Save state
+        ctx.fw_code = fw_code
+        ctx.bw_code = bw_code
+        ctx.bsz = bsz
+        ctx.bsz2 = bsz2
+        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
+
+        if any(ctx.needs_input_grad[:2]):
+            # NOTE: we send back A, and re-quant.
+            ctx.tensors = (A, fp8B)
+        else:
+            ctx.tensors = (None, None)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.is_empty:
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None
+
+        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
+        A, B = ctx.tensors
+
+        grad_A, grad_B = None, None
+        # TODO: Fix blocksize to be output_dim
+        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
+        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
+
+        cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
+        fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
+
+        # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+        # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
+        # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
+        # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
+
+        # not supported by PyTorch. TODO: create work-around
+        if req_gradA:
+            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
+
+        if req_gradB:
+            At = A.transpose(2, 1).contiguous()
+            cA, state = F.quantize(At.float(), code=ctx.bw_code)
+            fp8At = F.dequantize(cA, state).to(A.dtype)
+            grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)
+
+        return grad_A, grad_B, None, None, None, None, None
+
+
+class MatMulFP8Global(torch.autograd.Function):
+    # forward is the same, but we added the fallback for pre-turing GPUs
+    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
+
+    @staticmethod
+    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
+        # default of pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+
+            B_shape = B.shape
+            if A.shape[-1] == B_shape[0]:
+                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
+
+        # 1. Dequantize
+        # 2. MatmulnN
+        cA, state = F.quantize(A.float(), code=fw_code)
+        fp8A = F.dequantize(cA, state).to(A.dtype)
+
+        cB, state = F.quantize(B.float(), code=fw_code)
+        fp8B = F.dequantize(cB, state).to(B.dtype)
+
+        output = torch.matmul(fp8A, fp8B)
+
+        # output is half
+
+        # 3. Save state
+        ctx.fw_code = fw_code
+        ctx.bw_code = bw_code
+        ctx.bsz = bsz
+        ctx.bsz2 = bsz2
+        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
+
+        if any(ctx.needs_input_grad[:2]):
+            # NOTE: we send back A, and re-quant.
+            ctx.tensors = (A, fp8B)
+        else:
+            ctx.tensors = (None, None)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.is_empty:
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None
+
+        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
+        A, B = ctx.tensors
+
+        grad_A, grad_B = None, None
+        # TODO: Fix blocksize to be output_dim
+        cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code)
+        fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype)
+
+        # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
+        # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
+
+        # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+        # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
+        # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
+        # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
+
+        # not supported by PyTorch. TODO: create work-around
+        if req_gradA:
+            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
+
+        if req_gradB:
+            At = A.transpose(2, 1).contiguous()
+            cA, state = F.quantize(At.float(), code=ctx.fw_code)
+            fp8At = F.dequantize(cA, state).to(A.dtype)
+            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
+
+        return grad_A, grad_B, None, None, None, None, None
+
+
 class MatMul8bitMixed(torch.autograd.Function):
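The classes added above (MatMulFP8Mixed, MatMulFP4, MatMulFP8Global) all follow the same template as MatMulFP8: fake-quantize A and B in forward, stash the configuration on ctx, and return one gradient slot per forward argument. Adding bsz2 is why ctx.needs_input_grad is now unpacked into seven values and why the final return carries an extra None. A stripped-down sketch of that bookkeeping, with the quantization omitted (FakeQuantMatMul is an illustrative name, not part of the diff):

import torch

class FakeQuantMatMul(torch.autograd.Function):
    # Mirrors the MatMulFP8* template: forward takes 7 arguments, so
    # ctx.needs_input_grad has 7 entries and backward returns 7 values.

    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        # (quantize -> dequantize of A and B would happen here)
        ctx.save_for_backward(A, B)
        return torch.matmul(A, B)

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad   # 7 slots
        grad_A = torch.matmul(grad_output, B.t()) if req_gradA else None
        grad_B = torch.matmul(A.t(), grad_output) if req_gradB else None
        # one slot per forward argument: A, B, out, fw_code, bw_code, bsz, bsz2
        return grad_A, grad_B, None, None, None, None, None

A = torch.randn(8, 16, requires_grad=True)
B = torch.randn(16, 4, requires_grad=True)
FakeQuantMatMul.apply(A, B, None, None, None, 1024, 1024).sum().backward()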
@@ -520,12 +762,14 @@ class MatMul8bitMixed(torch.autograd.Function):
                 # we also need to convert it to the turing/ampere format
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
+            #print('A shape', A.shape)
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
             subA = None

         # 2. Quantize B
         if state.has_fp16_weights:
+            #print('B shape', B.shape)
             has_grad = True if (getattr(B, "grad", None) is not None) else False
             is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
             if is_transposed:
@@ -633,6 +877,8 @@ class MatMul8bitMixed(torch.autograd.Function):
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))

         if req_gradB:
+            # print('back A shape', A.shape)
+            # print('grad output t shape', grad_output.t().shape)
             grad_B = torch.matmul(grad_output.t(), A)

         if req_gradA:
@@ -642,6 +888,8 @@ class MatMul8bitMixed(torch.autograd.Function):
                 state.CxBt, state.SBt = F.transform(
                     state.CBt, to_order=formatB, transpose=True
                 )
+            # print('back B shape', state.CxBt.shape)
+            # print('back grad shape', C32grad.shape)
             gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
@@ -668,8 +916,18 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz: int = -1):
-    return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz)
+def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz: int = -1, bsz2: int = -1):
+    return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
+
+def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz: int = -1, bsz2: int = -1):
+    return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
+
+def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz: int = -1, bsz2: int = -1):
+    return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
+
+def matmul_fp4(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz: int = -1, bsz2: int = -1):
+    return MatMulFP4.apply(A, B, out, fw_code, bw_code, bsz, bsz2)


 def matmul_mixed(
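The wrappers above are thin forwarders to the corresponding autograd Functions. A usage sketch, assuming a CUDA build of this branch and the same create_fp8_map codes that LinearFP8.forward uses later in this commit (only the activation gradient is exercised here, since the weight is kept frozen):

import torch
import bitsandbytes as bnb

# FP8 code tables as used by LinearFP8.forward in this commit:
# 4-bit exponent / 3-bit mantissa forward, 5/2 backward.
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).cuda()
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).cuda()

A = torch.randn(4, 128, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
W = torch.randn(512, 1024, device="cuda", dtype=torch.float16)  # frozen weight

# bsz is the block size used to quantize A in forward; bsz2 is the block size
# used to quantize grad_output in backward.
out = bnb.matmul_fp8(A, W.t(), fw_code=fw_code, bw_code=bw_code, bsz=1024, bsz2=1024)
out.float().sum().backward()
print(A.grad.shape)  # torch.Size([4, 128, 1024])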
bitsandbytes/nn/__init__.py

@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed
bitsandbytes/nn/modules.py

@@ -498,14 +498,69 @@ class LinearFP8(nn.Linear):
             if input_features > array[i + 1]:
                 self.bsz = k
                 break
-        print('block size is', self.bsz)
+        for i, k in enumerate(array):
+            if output_features > array[i + 1]:
+                self.bsz2 = k
+                break

     def forward(self, x: torch.Tensor):
         if self.fw_code is None:
             self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
             self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
         if self.bias is not None:
             out += self.bias

         return out

+class LinearFP8Mixed(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.bw_code = None
+        self.fw_code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break
+        for i, k in enumerate(array):
+            if output_features > array[i + 1]:
+                self.bsz2 = k
+                break
+
+    def forward(self, x: torch.Tensor):
+        if self.fw_code is None:
+            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
+            self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
+
+        out = bnb.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
+        if self.bias is not None:
+            out += self.bias
+
+        return out
+
+class LinearFP8Global(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.bw_code = None
+        self.fw_code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break
+        for i, k in enumerate(array):
+            if output_features > array[i + 1]:
+                self.bsz2 = k
+                break
+
+    def forward(self, x: torch.Tensor):
+        if self.fw_code is None:
+            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
+            self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
+
+        out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
+        if self.bias is not None:
+            out += self.bias
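The constructors above choose bsz from input_features and bsz2 from output_features with the same loop: the feature count is rounded up to the next entry in [4096, 2048, 1024, 512, 256, 128, 64], capped at 4096. A small standalone check of that selection rule (pick_blocksize is an illustrative helper, not part of the diff):

def pick_blocksize(features, array=(4096, 2048, 1024, 512, 256, 128, 64, 0)):
    # Same rule as the module constructors: take the first (largest) entry k
    # whose successor in the list is smaller than `features`.
    for i, k in enumerate(array):
        if features > array[i + 1]:
            return k

for n in (50, 64, 100, 768, 1024, 1500, 4096, 8192):
    print(n, "->", pick_blocksize(n))
# 50 -> 64, 64 -> 64, 100 -> 128, 768 -> 1024, 1024 -> 1024, 1500 -> 2048, 4096 -> 4096, 8192 -> 4096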
@@ -520,12 +575,16 @@ class LinearInt8(nn.Linear):
             if input_features > array[i + 1]:
                 self.bsz = k
                 break
+        for i, k in enumerate(array):
+            if output_features > array[i + 1]:
+                self.bsz2 = k
+                break

     def forward(self, x: torch.Tensor):
         if self.code is None:
             self.code = bnb.functional.create_linear_map(True, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz, bsz2=self.bsz2)
         if self.bias is not None:
             out += self.bias
@@ -553,3 +612,30 @@ class LinearInt8Cast(nn.Linear):
         return out


+class LinearFP4(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.bw_code = None
+        self.fw_code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break
+        for i, k in enumerate(array):
+            if output_features > array[i + 1]:
+                self.bsz2 = k
+                break
+
+    def forward(self, x: torch.Tensor):
+        if self.fw_code is None:
+            #self.bw_code = bnb.functional.create_fp8_map(True, 3, 0, 4).to(x.device)
+            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
+            self.fw_code = bnb.functional.create_fp8_map(True, 3, 0, 4).to(x.device)
+
+        out = bnb.matmul_fp4(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
+        if self.bias is not None:
+            out += self.bias
+
+        return out
\ No newline at end of file
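Taken together, the new modules are intended as drop-in replacements for nn.Linear, with block sizes derived automatically from the layer dimensions. A forward-pass sketch, assuming a CUDA build of this branch:

import torch
import torch.nn as nn
from bitsandbytes.nn import LinearFP8Mixed, LinearFP4  # newly exported by this commit

mlp = nn.Sequential(
    LinearFP8Mixed(1024, 4096),
    nn.GELU(),
    LinearFP4(4096, 1024),
).cuda().half()

x = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16)
with torch.no_grad():
    y = mlp(x)
print(y.shape)  # torch.Size([8, 128, 1024])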