OpenDAS / bitsandbytes

Commit 2bb5c00b
Authored Apr 11, 2023 by Tim Dettmers
Parent: 29ab3a6b

Added pre/post call to all lib calls. Fixes #120
Showing 1 changed file with 13 additions and 0 deletions.

bitsandbytes/functional.py  (+13, −0)
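The pattern this commit applies: every native library call now runs between pre_call(tensor.device) and post_call(prev_device). A minimal sketch of what that pair is assumed to do (switch the active CUDA device to the tensor's device, then restore the previous one); this is an illustration, not the library source:

import torch

# Assumed behavior of the pre_call/post_call helpers used in this commit (sketch only).
def pre_call(device):
    # Remember the currently active CUDA device, then switch to the target device.
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device

def post_call(prev_device):
    # Restore the device that was active before the library call.
    torch.cuda.set_device(prev_device)

Under that assumption, the CUDA kernels launched through lib.* run on the same device as the tensors they operate on, instead of whatever device happens to be current.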
@@ -770,6 +770,8 @@ def optimizer_update_32bit(
             f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
         )

+    prev_device = pre_call(g.device)
+    is_on_gpu([g, p, state1, state2, unorm_vec])
     if g.dtype == torch.float32 and state1.dtype == torch.float32:
         str2optimizer32bit[optimizer_name][0](
             get_ptr(g),
@@ -812,6 +814,7 @@ def optimizer_update_32bit(
         raise ValueError(
             f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
         )
+    post_call(prev_device)


 def optimizer_update_8bit(
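To see why the device guard matters on the optimizer path, here is a hypothetical setup of the kind it protects (illustrative only, not necessarily the exact reproduction from issue #120): parameters and optimizer state live on a non-default GPU, so the update kernel must be launched on that GPU rather than on the current one.

import torch
import bitsandbytes as bnb

# Hypothetical setup: everything lives on cuda:1 while cuda:0 is the current device.
model = torch.nn.Linear(1024, 1024).to("cuda:1")
opt = bnb.optim.Adam8bit(model.parameters())

loss = model(torch.randn(4, 1024, device="cuda:1")).sum()
loss.backward()
opt.step()  # the optimizer_update_* kernels now run with the active device set to cuda:1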
@@ -890,6 +893,8 @@ def optimizer_update_8bit(
     if max_unorm > 0.0:
         param_norm = torch.norm(p.data.float())

+    prev_device = pre_call(g.device)
+    is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
         str2optimizer8bit[optimizer_name][0](
             get_ptr(p),
@@ -942,6 +947,7 @@ def optimizer_update_8bit(
         raise ValueError(
             f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
         )
+    post_call(prev_device)


 def optimizer_update_8bit_blockwise(
@@ -964,6 +970,8 @@ def optimizer_update_8bit_blockwise(
     skip_zeros=False,
 ) -> None:

+    prev_device = pre_call(g.device)
+    is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
         str2optimizer8bit_blockwise[optimizer_name][0](
             get_ptr(p),
@@ -1008,6 +1016,7 @@ def optimizer_update_8bit_blockwise(
         raise ValueError(
             f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
         )
+    post_call(prev_device)


 def percentile_clipping(
@@ -1023,6 +1032,7 @@ def percentile_clipping(
         The current optimiation steps (number of past gradient norms).

     """
+    prev_device = pre_call(grad.device)
     is_on_gpu([grad, gnorm_vec])
     if grad.dtype == torch.float32:
         lib.cpercentile_clipping_g32(
@@ -1040,6 +1050,7 @@ def percentile_clipping(
         )
     else:
         raise ValueError(f"Gradient type {grad.dtype} not supported!")
+    post_call(prev_device)

     current_gnorm = torch.sqrt(gnorm_vec[step % 100])
     vals, idx = torch.sort(gnorm_vec)
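For orientation, percentile_clipping keeps the squared norms of recent gradients in gnorm_vec (100 slots) and clips the current gradient against a percentile of that history. A rough pure-PyTorch sketch of the logic follows; the real work happens inside CUDA kernels such as cpercentile_clipping_g32, and the exact buffer layout and return values here are assumptions:

import torch

def percentile_clipping_sketch(grad, gnorm_vec, step, percentile=5):
    # Record the squared norm of the current gradient in a 100-slot ring buffer.
    gnorm_vec[step % 100] = grad.float().norm() ** 2
    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    # Clip threshold: the requested percentile of the recorded (sorted) history.
    vals, _ = torch.sort(gnorm_vec)
    clip_value = torch.sqrt(vals[percentile])
    # Scale factor to apply to the gradient if it exceeds the threshold.
    gnorm_scale = torch.clamp(clip_value / current_gnorm, max=1.0)
    return current_gnorm, clip_value, gnorm_scale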
@@ -1796,6 +1807,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
             (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
         )
     nnz = cooA.nnz
+    prev_device = pre_call(B.device)
     assert cooA.rowidx.numel() == nnz
     assert cooA.colidx.numel() == nnz
     assert cooA.values.numel() == nnz
@@ -1872,6 +1884,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
         ccolsB,
     )
     # else: assertion error
+    post_call(prev_device)
     return out
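A note on the approach, for comparison only: the explicit pre_call/post_call pair has essentially the same effect as PyTorch's built-in device guard. The snippet below shows the context-manager form of the same idea; the commit itself keeps the explicit calls around each lib.* invocation.

import torch

B = torch.randn(8, 8, device="cuda:1")

# Assumed-equivalent effect to pre_call(B.device) ... post_call(prev_device):
with torch.cuda.device(B.device):
    pass  # any kernel launched here targets B's device; the previous device is restored on exit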