OpenDAS / bitsandbytes · Commit a13a522c
Authored Mar 31, 2023 by Tim Dettmers
Parent: b373034e

    Added first triton test.

Showing 6 changed files with 19 additions and 98 deletions:
  bitsandbytes/nn/triton_based_modules.py       +3   -82
  tests/triton_tests/attn_decomp.py             +3   -3
  tests/triton_tests/full_matrix_decomp.py      +2   -2
  tests/triton_tests/mlp.py                     +3   -3
  tests/triton_tests/mlp_decomp_autocast.py     +4   -4
  tests/triton_tests/mlp_decomp_autocast_ln.py  +4   -4
bitsandbytes/nn/triton_based_modules.py  (view file @ a13a522c)

@@ -133,7 +133,7 @@ class SwitchBackGlobalLinear(nn.Linear):

-class LinearFunction(torch.autograd.Function):
+class StandardLinearFunction(torch.autograd.Function):

     @staticmethod
     def forward(ctx, input, weight, bias=None):
         X = input.view(-1, input.size(-1))

@@ -161,87 +161,8 @@ class LinearFunction(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias

-class MyLinear(nn.Linear):
+class StandardLinear(nn.Linear):

     def forward(self, x):
-        return LinearFunction.apply(x, self.weight, self.bias)
+        return StandardLinearFunction.apply(x, self.weight, self.bias)
-
-
-class _switchback_mlp(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, X_3D, W1, B1, W2, B2):
-
-        X1 = X_3D.view(-1, X_3D.size(-1))
-
-        X1_int8, state_X1 = quantize_rowwise_nogroup(X1)
-        W1_int8, state_W1 = quantize_global(W1)
-        X2_pre = int8_matmul_mixed_dequanitze_bias(X1_int8, W1_int8.t(), state_X1, state_W1, B1)
-
-        # X2_v1 = torch.nn.functional.gelu(X2)
-        # X2_int8, state_X2, = quantize_rowwise_nogroup(X2_v1)
-        X2_int8, state_X2, X2 = quantize_rowwise_nogroup_gelu(X2_pre)
-
-        W2_int8, state_W2 = quantize_global(W2)
-        out = int8_matmul_mixed_dequanitze_bias(X2_int8, W2_int8.t(), state_X2, state_W2, B2)
-
-        ctx.save_for_backward = X1, W1, X2, X2_pre, W2
-
-        return out.view(*X_3D.size()[:-1], -1)
-
-    @staticmethod
-    def backward(ctx, G_3D):
-
-        G2 = G_3D.reshape(-1, G_3D.size(-1))
-
-        grad_X1 = grad_W1 = grad_B1 = grad_W2 = grad_B2 = None
-
-        X1, W1, X2, X2_pre, W2 = ctx.save_for_backward
-
-        G2_int8, state_G2 = quantize_rowwise_nogroup(G2)
-        W2_int8, state_W2 = quantize_global_transpose(W2)
-        G1 = int8_matmul_mixed_dequanitze(G2_int8, W2_int8.t(), state_G2, state_W2).view(*G_3D.size()[:-1], -1)
-
-        grad_W2 = torch.matmul(G2.t(), X2.to(G2.dtype))
-        grad_B2 = G2.sum(dim=0)
-
-        G1_int8, state_G1, G1 = quantize_rowwise_nogroup_back_gelu(G1, X2_pre)
-
-        if ctx.needs_input_grad[0]:
-            W1_int8, state_W1 = quantize_global_transpose(W1)
-            grad_X1 = int8_matmul_mixed_dequanitze(G1_int8, W1_int8.t(), state_G1, state_W1).view(*G_3D.size()[:-1], -1)
-
-        if ctx.needs_input_grad[1]:
-            grad_W1 = torch.matmul(G1.t(), X1.to(G1.dtype))
-
-        if ctx.needs_input_grad[2]:
-            grad_B1 = G1.sum(dim=0)
-
-        return grad_X1, grad_W1, grad_B1, grad_W2, grad_B2
-
-
-class SwitchBackGlobalMLP(nn.Module):
-
-    def __init__(self, dim_in, dim_hidden):
-        super().__init__()
-        self.linear1 = nn.Linear(dim_in, dim_hidden)
-        self.linear2 = nn.Linear(dim_hidden, dim_in)
-
-    def forward(self, x):
-        return _switchback_mlp.apply(x, self.linear1.weight, self.linear1.bias, self.linear2.weight, self.linear2.bias)
\ No newline at end of file
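Note: the kernels used in the fused path above (quantize_rowwise_nogroup, quantize_global, int8_matmul_mixed_dequanitze_bias) implement a mixed row-wise/global int8 scheme. As a rough reference only, and not the library's kernel semantics, the arithmetic they stand for can be sketched in plain PyTorch as follows (the *_ref names, rounding mode, and scale layout are assumptions):

import torch

def quantize_rowwise_ref(x):
    # symmetric int8, one absmax scale per row (the "state" the row-wise kernel returns)
    state = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    return torch.round(x / state * 127).to(torch.int8), state

def quantize_global_ref(w):
    # symmetric int8, a single absmax scale for the whole weight tensor
    state = w.abs().amax().clamp(min=1e-8)
    return torch.round(w / state * 127).to(torch.int8), state

def int8_matmul_mixed_dequantize_ref(x_int8, w_int8_t, state_x, state_w, bias=None):
    # the Triton kernel presumably accumulates in int32; float32 is used here only for portability
    acc = x_int8.to(torch.float32) @ w_int8_t.to(torch.float32)
    out = acc * state_x * state_w / (127.0 * 127.0)
    return out if bias is None else out + bias

# quick sanity check against an fp32 matmul
x, w = torch.randn(8, 64), torch.randn(32, 64)
x_q, sx = quantize_rowwise_ref(x)
w_q, sw = quantize_global_ref(w)
err = (int8_matmul_mixed_dequantize_ref(x_q, w_q.t(), sx, sw) - x @ w.t()).abs().max()
print(f"max abs error vs fp32: {err.item():.4f}")  # small quantization error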
tests/triton_tests/attn_decomp.py  (view file @ a13a522c)

 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time

 # class AttentionOld(torch.nn.Module):

@@ -116,7 +116,7 @@ if __name__ == '__main__':
     va = torch.randn(batch // 256, 256, dim).cuda().requires_grad_(True)

     standard = Attention(dim).cuda()
-    my_standard = Attention(dim, linear_module=MyLinear).cuda()
+    my_standard = Attention(dim, linear_module=StandardLinear).cuda()
     sb = Attention(dim, linear_module=SwitchBackGlobalLinear).cuda()
     standard_compiled = torch.compile(standard)
     ln_model = torch.nn.Sequential(

@@ -360,4 +360,4 @@ if __name__ == '__main__':
     # import pdb; pdb.set_trace()
     # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
\ No newline at end of file
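Note: this script builds the same Attention block with nn.Linear, StandardLinear, and SwitchBackGlobalLinear and times them on GPU. A generic forward+backward timing harness for that kind of comparison might look like the sketch below (the bench helper and its loop structure are illustrative assumptions, not code from the file):

import time
import torch

def bench(model, x, warmup=8, repeat=32):
    # warm up kernels and the allocator before timing; requires a CUDA device
    for _ in range(warmup):
        model(x).sum().backward()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        model(x).sum().backward()
    torch.cuda.synchronize()
    return (time.time() - start) / repeat * 1000  # ms per iteration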
tests/triton_tests/full_matrix_decomp.py  (view file @ a13a522c)

@@ -4,7 +4,7 @@ import time
 import torch
 import torch.nn as nn
 import bitsandbytes.nn as bnn
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear
 from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
 from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose

@@ -350,4 +350,4 @@ if __name__ == '__main__':
     with open("tests/triton_tests/info.jsonl", "a") as file:
         file.write(info_json + "\n")
\ No newline at end of file
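Note: the script appends one JSON object per run to tests/triton_tests/info.jsonl. A minimal illustration of that logging pattern, with hypothetical field names and placeholder values rather than real measurements, is:

import json

# placeholder record only; the actual fields written by the script are not shown in this diff
info = {"dim": 1024, "batch": 8192, "standard_ms": None, "switchback_ms": None}
info_json = json.dumps(info)
with open("tests/triton_tests/info.jsonl", "a") as file:
    file.write(info_json + "\n")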
tests/triton_tests/mlp.py  (view file @ a13a522c)

@@ -3,7 +3,7 @@ import time
 import torch
 import torch.nn as nn
 import bitsandbytes.nn as bnn
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear

 def construct_model(dim, layers, module):
     modules = []

@@ -41,7 +41,7 @@ if __name__ == '__main__':
     # construct models
     standard = construct_model(dim, layers, nn.Linear).half()
-    my_standard = construct_model(dim, layers, MyLinear).half()
+    my_standard = construct_model(dim, layers, StandardLinear).half()
     switchback = construct_model(dim, layers, SwitchBackLinear).half()
     switchback_global = construct_model(dim, layers, SwitchBackGlobalLinear).half()
     #bnb_8bitmixed = construct_model(dim, layers, bnn.Linear8bitLt)

@@ -61,4 +61,4 @@ if __name__ == '__main__':
\ No newline at end of file
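Note: only the signature of construct_model and its empty modules list are visible in this hunk. A plausible completion, assumed rather than taken from the file, stacks the given drop-in linear module with GELU activations so any of the imported linear classes can be benchmarked through the same builder:

import torch.nn as nn

def construct_model(dim, layers, module):
    # `module` is any nn.Linear-compatible class, e.g. nn.Linear, StandardLinear, SwitchBackGlobalLinear
    modules = []
    for _ in range(layers):
        modules.append(module(dim, 4 * dim))   # 4x hidden expansion is an assumption
        modules.append(nn.GELU())
        modules.append(module(4 * dim, dim))
    return nn.Sequential(*modules)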
tests/triton_tests/mlp_decomp_autocast.py  (view file @ a13a522c)

 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time

 if __name__ == '__main__':

@@ -26,9 +26,9 @@ if __name__ == '__main__':
     ).cuda()

     my_standard = torch.nn.Sequential(
-        MyLinear(dim, 4 * dim),
+        StandardLinear(dim, 4 * dim),
         torch.nn.GELU(),
-        MyLinear(4 * dim, dim),
+        StandardLinear(4 * dim, dim),
     ).cuda()

     fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()

@@ -163,4 +163,4 @@ if __name__ == '__main__':
     # import pdb; pdb.set_trace()
     # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
\ No newline at end of file
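Note: this script compares the two-layer StandardLinear MLP against the fused SwitchBackGlobalMLP under autocast. A minimal usage sketch is shown below; the sizes are illustrative, and whether the fused int8 path participates in autocast's dtype casting depends on its kernels, so treat this as a sketch rather than the file's procedure:

import torch
from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP

dim = 1024                                        # illustrative size
x = torch.randn(4096, dim, device="cuda")
fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()

with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = fused_mlp(x)
out.float().sum().backward()                      # gradients flow through the custom autograd.Function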
tests/triton_tests/mlp_decomp_autocast_ln.py  (view file @ a13a522c)

 import torch
 import json
-from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
+from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
 import time

 if __name__ == '__main__':

@@ -24,9 +24,9 @@ if __name__ == '__main__':
     my_standard = torch.nn.Sequential(
         torch.nn.LayerNorm(dim),
-        MyLinear(dim, 4 * dim),
+        StandardLinear(dim, 4 * dim),
         torch.nn.GELU(),
-        MyLinear(4 * dim, dim),
+        StandardLinear(4 * dim, dim),
     ).cuda()

     fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()

@@ -162,4 +162,4 @@ if __name__ == '__main__':
     # import pdb; pdb.set_trace()
     # # NO GELU, ST GRADIENTS, EVERYTHING FINE.
\ No newline at end of file
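Note: after this commit, every test imports StandardLinear as the non-quantized baseline. The diff shows only fragments of it (the first line of its forward and the backward's return values), so the self-contained version below is a textbook reconstruction under that assumption, not the file's exact code:

import torch
import torch.nn as nn

class StandardLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        X = input.view(-1, input.size(-1))
        ctx.save_for_backward(X, weight, bias)
        output = X @ weight.t()
        if bias is not None:
            output = output + bias
        return output.view(*input.size()[:-1], -1)

    @staticmethod
    def backward(ctx, grad_output):
        X, weight, bias = ctx.saved_tensors
        G = grad_output.reshape(-1, grad_output.size(-1))
        grad_input = (G @ weight).view(*grad_output.size()[:-1], -1)
        grad_weight = G.t() @ X
        grad_bias = G.sum(dim=0) if bias is not None else None
        return grad_input, grad_weight, grad_bias

class StandardLinear(nn.Linear):
    def forward(self, x):
        return StandardLinearFunction.apply(x, self.weight, self.bias)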