Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
41a9c708
Commit
41a9c708
authored
May 06, 2023
by
Tim Dettmers
Browse files
Changed prefetching.
parent
44d68ff2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
2 deletions
+14
-2
bitsandbytes/functional.py
bitsandbytes/functional.py
+4
-1
bitsandbytes/optim/optimizer.py
bitsandbytes/optim/optimizer.py
+10
-1
No files found.
bitsandbytes/functional.py
View file @
41a9c708
...
@@ -100,7 +100,10 @@ class GlobalPageManager:
...
@@ -100,7 +100,10 @@ class GlobalPageManager:
return
cls
.
_instance
return
cls
.
_instance
def prefetch_all(self, to_cpu=False):
    """Prefetch every registered paged tensor.

    Tensors are walked in reverse registration order: the tensors that
    were added first are assumed to be the ones used first, so they are
    swapped in last and are therefore the least likely to be evicted
    again before they are needed.

    Args:
        to_cpu: if True, prefetch toward CPU memory instead of the
            device.  NOTE(review): exact semantics depend on
            ``prefetch_tensor`` — confirm there.
    """
    for paged_tensor in reversed(self.paged_tensors):
        prefetch_tensor(paged_tensor, to_cpu)
...
...
bitsandbytes/optim/optimizer.py
View file @
41a9c708
...
@@ -256,7 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer):
...
@@ -256,7 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self
.
to_gpu
()
# needed for fairseq pure fp16 training
self
.
to_gpu
()
# needed for fairseq pure fp16 training
self
.
initialized
=
True
self
.
initialized
=
True
if
self
.
is_paged
:
self
.
page_mng
.
prefetch_all
()
#
if self.is_paged: self.page_mng.prefetch_all()
for
gindex
,
group
in
enumerate
(
self
.
param_groups
):
for
gindex
,
group
in
enumerate
(
self
.
param_groups
):
for
pindex
,
p
in
enumerate
(
group
[
"params"
]):
for
pindex
,
p
in
enumerate
(
group
[
"params"
]):
if
p
.
grad
is
None
:
if
p
.
grad
is
None
:
...
@@ -265,7 +265,9 @@ class Optimizer8bit(torch.optim.Optimizer):
...
@@ -265,7 +265,9 @@ class Optimizer8bit(torch.optim.Optimizer):
if
len
(
state
)
==
0
:
if
len
(
state
)
==
0
:
self
.
init_state
(
group
,
p
,
gindex
,
pindex
)
self
.
init_state
(
group
,
p
,
gindex
,
pindex
)
self
.
prefetch_state
(
p
)
self
.
update_step
(
group
,
p
,
gindex
,
pindex
)
self
.
update_step
(
group
,
p
,
gindex
,
pindex
)
torch
.
cuda
.
synchronize
()
if
self
.
is_paged
:
if
self
.
is_paged
:
# all paged operation are asynchronous, we need
# all paged operation are asynchronous, we need
# to sync to make sure all tensors are in the right state
# to sync to make sure all tensors are in the right state
...
@@ -309,6 +311,13 @@ class Optimizer8bit(torch.optim.Optimizer):
...
@@ -309,6 +311,13 @@ class Optimizer8bit(torch.optim.Optimizer):
self
.
page_mng
.
paged_tensors
.
append
(
buff
)
self
.
page_mng
.
paged_tensors
.
append
(
buff
)
return
buff
return
buff
def prefetch_state(self, p):
    """Asynchronously prefetch the paged optimizer state for parameter ``p``.

    Does nothing unless this optimizer is paged.  When paged, the
    per-parameter state dict is looked up and its ``'state1'`` buffer is
    prefetched, followed by ``'state2'`` when that entry exists.
    """
    if not self.is_paged:
        return
    param_state = self.state[p]
    F.prefetch_tensor(param_state['state1'])
    if 'state2' in param_state:
        F.prefetch_tensor(param_state['state2'])
class
Optimizer2State
(
Optimizer8bit
):
class
Optimizer2State
(
Optimizer8bit
):
def
__init__
(
def
__init__
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment