OpenDAS / bitsandbytes

Commit da524d97
Authored Apr 08, 2023 by Mitchell Wortsman
"mem efficient"
parent eb6c53cf
Showing 2 changed files with 119 additions and 2 deletions (+119, -2):

    bitsandbytes/nn/triton_based_modules.py                  +61  -2
    bitsandbytes/nn/triton_utils/v0/dequantize_rowwise.py    +58  -0
bitsandbytes/nn/triton_based_modules.py
@@ -3,6 +3,7 @@ import torch.nn as nn
 import time
 from functools import partial

+from .triton_utils.v0.dequantize_rowwise import dequantize_rowwise
 from .triton_utils.v0.quantize_rowwise import quantize_rowwise
 from .triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
@@ -98,6 +99,56 @@ class _switchback_vectorrize(torch.autograd.Function):

         return grad_X, grad_W, grad_bias

+class _switchback_global_mem_efficient(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, X_3D, W, bias):
+        # reshape input to [N * L, D]
+        X = X_3D.view(-1, X_3D.size(-1))
+        X_3D_sz = X_3D.size()
+
+        # rowwise quantize for X, global quantize for W
+        X_int8, state_X = quantize_rowwise(X)
+        del X
+        W_int8, state_W = quantize_global(W)
+
+        print('in mem eff backward.')
+
+        # save for backward.
+        ctx.save_for_backward = X_int8, state_X, W_int8, state_W
+
+        # matmult, fused dequant and add bias
+        # call "mixed" because we are mixing rowwise quantized and global quantized
+        return int8_matmul_mixed_dequanitze(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D_sz[:-1], -1)
+
+    @staticmethod
+    def backward(ctx, G_3D):
+        # reshape input to [N_out * L, D]
+        G = G_3D.reshape(-1, G_3D.size(-1))
+        G_3D_sz = G_3D.size()
+
+        grad_X = grad_W = grad_bias = None
+
+        X_int8, state_X, W_int8, state_W = ctx.save_for_backward
+        if ctx.needs_input_grad[1]:
+            real_X = dequantize_rowwise(X_int8, state_X)
+            del X_int8
+            grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
+            del real_X
+        if ctx.needs_input_grad[2]:
+            grad_bias = G.sum(dim=0)
+        if ctx.needs_input_grad[0]:
+            G_int8, state_G = quantize_rowwise(G)
+            del G
+            W_int8 = W_int8.t().contiguous()
+            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(*G_3D_sz[:-1], -1)
+
+        return grad_X, grad_W, grad_bias
+
 class SwitchBackLinear(nn.Linear):
     def __init__(
             self,
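The class added above is the memory-efficient variant of the global SwitchBack path: between forward and backward it keeps only the int8 copy of X plus its per-row quantization state, and dequantizes X on demand when grad_W is needed. Note that it stashes the tensors by assigning a tuple to ctx.save_for_backward rather than calling the method, which bypasses autograd's usual saved-tensor bookkeeping. A rough back-of-envelope sketch of the saving follows, under the assumption that the non-mem-efficient path keeps a 16-bit copy of X for backward and that state_X is one 2-byte scale per row; the shapes are illustrative only.

    # Back-of-envelope sketch, not part of the commit.
    # Assumption: the existing path saves X in fp16; the new path saves X_int8
    # plus a 2-byte per-row absmax (state_X).
    N, L, D = 8, 256, 1024                      # batch, sequence length, hidden dim
    fp16_bytes = 2 * N * L * D                  # 16-bit activation copy
    int8_bytes = 1 * N * L * D + 2 * (N * L)    # int8 copy + per-row scale
    print(fp16_bytes / int8_bytes)              # ~2.0x less activation memory for this layer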
@@ -106,7 +157,8 @@ class SwitchBackLinear(nn.Linear):
             bias: bool = True,
             device=None,
             dtype=None,
-            vectorize: bool = False
+            vectorize: bool = False,
+            mem_efficient : bool = False,
         ):
         super().__init__(in_features, out_features, bias, device, dtype)
@@ -114,6 +166,12 @@ class SwitchBackLinear(nn.Linear):
         self.vectorize = vectorize
         if self.vectorize:
             self._fn = _switchback_vectorrize
+            if mem_efficient:
+                print('mem efficient is not supported for vectorize mode.')
+                exit(1)
         else:
-            self._fn = _switchback_global
+            if mem_efficient:
+                self._fn = _switchback_global_mem_efficient
+            else:
+                self._fn = _switchback_global
@@ -158,6 +216,7 @@ class SwitchBackLinear(nn.Linear):
         ).view(*x.size()[:-1], -1)

 SwitchBackLinearGlobal = partial(SwitchBackLinear, vectorize=False)
+SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vectorize=False, mem_efficient=True)
 SwitchBackLinearVectorized = partial(SwitchBackLinear, vectorize=True)

 # This is just the standard linear function.
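For context, here is a minimal usage sketch of the new SwitchBackLinearGlobalMemEfficient partial. It is not part of the commit, and the import path, layer sizes, fp16 dtype, and .cuda() placement are illustrative assumptions (the underlying Triton kernels assert that their inputs are CUDA tensors).

    # Hypothetical usage sketch, not part of the commit.
    import torch
    from bitsandbytes.nn.triton_based_modules import SwitchBackLinearGlobalMemEfficient

    layer = SwitchBackLinearGlobalMemEfficient(1024, 4096, bias=True).cuda().half()
    x = torch.randn(8, 256, 1024, device='cuda', dtype=torch.float16, requires_grad=True)

    y = layer(x)        # forward: rowwise-quantize X, globally quantize W, fused int8 matmul + dequant
    y.sum().backward()  # backward: the saved int8 X is dequantized only if grad_W is needed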
bitsandbytes/nn/triton_utils/v0/dequantize_rowwise.py (new file, mode 100644)
import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# rowwise quantize

# TODO: autotune this better.
@triton.autotune(
        configs=[
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
        key=['n_elements']
)
@triton.jit
def _dequantize_rowwise(
    x_ptr,
    state_x,
    output_ptr,
    inv_127,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    arange = tl.arange(0, P2)
    offsets = block_start + arange
    row_mask = arange < BLOCK_SIZE
    x = tl.load(x_ptr + offsets, mask=row_mask)
    max_val = tl.load(state_x + pid)
    output = max_val * x * inv_127
    tl.store(output_ptr + offsets, output, mask=row_mask)


def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
    output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)

    P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))

    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (x.shape[0],)
    _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
    return output
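The kernel launches one program per row: P2 rounds the row width up to the next power of two (tl.arange needs a power-of-two extent), row_mask discards the padding, and each loaded int8 value is rescaled by the row's stored absmax times 1/127. A minimal round-trip sketch follows; it is not part of the commit and assumes that quantize_rowwise from the sibling module returns the int8 tensor together with the per-row absmax that dequantize_rowwise expects as state_x (which matches how the pair is used in triton_based_modules.py above); shapes are illustrative and CUDA is required.

    # Hedged round-trip sketch, not part of the commit.
    import torch
    from bitsandbytes.nn.triton_utils.v0.quantize_rowwise import quantize_rowwise
    from bitsandbytes.nn.triton_utils.v0.dequantize_rowwise import dequantize_rowwise

    x = torch.randn(4, 512, device='cuda', dtype=torch.float16)
    x_int8, state_x = quantize_rowwise(x)        # int8 values + per-row absmax
    x_hat = dequantize_rowwise(x_int8, state_x)  # ~x, up to int8 rounding error
    print((x - x_hat).abs().max())               # error on the order of absmax / 254 per row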