Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
ba61109b
Unverified
Commit
ba61109b
authored
Aug 26, 2022
by
Jiarui Fang
Committed by
GitHub
Aug 26, 2022
Browse files
[FAW] remove code related to chunk (#1501)
parent
d5085bb3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
32 deletions
+27
-32
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+26
-27
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
...n/parallel/layers/cache_embedding/freq_aware_embedding.py
+1
-5
No files found.
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
View file @
ba61109b
...
...
@@ -56,7 +56,6 @@ class CachedParamMgr(torch.nn.Module):
self
.
num_hits_history
=
[]
self
.
num_miss_history
=
[]
self
.
num_write_back_history
=
[]
self
.
input_id_percent_in_load_chunk
=
[]
self
.
_reset_comm_stats
()
self
.
_evict_strategy
=
evict_strategy
...
...
@@ -156,23 +155,23 @@ class CachedParamMgr(torch.nn.Module):
# self.cuda_cached_weight = self.weight
raise
NotImplementedError
()
def
cpu_weight_data
(
self
,
chunk
_id
:
int
)
->
torch
.
Tensor
:
def
cpu_weight_data
(
self
,
row
_id
x
:
int
)
->
torch
.
Tensor
:
"""
access a
chunk
of CPU weight.
access a
row
of CPU weight.
Args:
chunk
_id (int):
chunk id
row
_id
x
(int):
the idx of rows
Returns:
torch.Tensor: a piece of memory in CPU weight corresponding to
chunk
id's payload. The tensor is 1-D.
torch.Tensor: a piece of memory in CPU weight corresponding to
row
id's payload. The tensor is 1-D.
"""
return
self
.
weight
.
data
.
view
(
-
1
).
narrow
(
0
,
int
(
chunk
_id
)
*
self
.
embedding_dim
,
int
(
row
_id
x
)
*
self
.
embedding_dim
,
self
.
embedding_dim
).
view
(
1
,
self
.
embedding_dim
)
@
property
def
cuda_available_
chunk
_num
(
self
):
def
cuda_available_
row
_num
(
self
):
return
self
.
_cuda_available_row_num
@
torch
.
no_grad
()
...
...
@@ -202,7 +201,7 @@ class CachedParamMgr(torch.nn.Module):
preload_row_num
=
min
(
int
(
np
.
ceil
(
self
.
cuda_row_num
*
warmup_ratio
)),
self
.
num_embeddings
)
if
preload_row_num
>
0
:
with
Timer
()
as
timer
:
# extract
chunk
s from cpu weight
# extract
row
s from cpu weight
preload_row_ids
=
torch
.
arange
(
preload_row_num
)
preload_slot_ids
=
preload_row_ids
.
cuda
()
...
...
@@ -213,8 +212,8 @@ class CachedParamMgr(torch.nn.Module):
src
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
),
tgt
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
))
else
:
preload_
chunk
s
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_select
(
0
,
preload_row_ids
).
cuda
()
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_copy_
(
0
,
preload_slot_ids
,
preload_
chunk
s
)
preload_
row
s
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_select
(
0
,
preload_row_ids
).
cuda
()
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_copy_
(
0
,
preload_slot_ids
,
preload_
row
s
)
# update auxiliary info
slot_offsets
=
preload_slot_ids
...
...
@@ -224,15 +223,15 @@ class CachedParamMgr(torch.nn.Module):
print
(
f
'Cache warmup finished cost
{
timer
.
elapsed
}
sec.'
)
def
flush
(
self
):
"""flush all CUDA
chunk
s to CPU.
"""flush all CUDA
row
s to CPU.
The function is usually called after training finished.
"""
slots
=
torch
.
nonzero
(
self
.
cached_idx_map
>
-
1
).
squeeze
(
1
)
chunk
_ids
=
self
.
cached_idx_map
[
slots
]
chunk
s
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_select
(
0
,
slots
).
cpu
()
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_copy_
(
0
,
chunk
_ids
.
cpu
(),
chunk
s
)
row
_ids
=
self
.
cached_idx_map
[
slots
]
row
s
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_select
(
0
,
slots
).
cpu
()
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_copy_
(
0
,
row
_ids
.
cpu
(),
row
s
)
self
.
cached_idx_map
.
index_fill_
(
0
,
slots
,
-
1
)
self
.
inverted_cached_idx
.
index_fill_
(
0
,
chunk
_ids
,
-
1
)
self
.
inverted_cached_idx
.
index_fill_
(
0
,
row
_ids
,
-
1
)
self
.
_cuda_available_row_num
+=
slots
.
numel
()
assert
self
.
_cuda_available_row_num
==
self
.
cuda_row_num
...
...
@@ -280,25 +279,25 @@ class CachedParamMgr(torch.nn.Module):
cpu_row_idxs
=
torch
.
unique
(
cpu_row_idxs_original
)
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
f
"
the input indices pull
{
len
(
cpu_row_idxs
)
}
chunks,
"
\
f
"
which
is larger than the
presented
{
self
.
cuda_row_num
}
, "
\
f
"
p
lease increase cuda_row_num
shrink
batch size"
f
"
You move
{
len
(
cpu_row_idxs
)
}
embedding rows from CPU to CUDA.
"
\
f
"
It
is larger than the
capacity of the cache, which at most contains
{
self
.
cuda_row_num
}
rows
, "
\
f
"
P
lease increase cuda_row_num
or decrease the training
batch size
.
"
self
.
evict_backlist
=
cpu_row_idxs
with
record_function
(
"(zhg) get cpu
chunk indice
s"
):
with
record_function
(
"(zhg) get cpu
row idx
s"
):
comm_cpu_row_idxs
=
cpu_row_idxs
[
torch
.
isin
(
cpu_row_idxs
,
self
.
cached_idx_map
,
invert
=
True
)]
self
.
num_hits_history
.
append
(
len
(
cpu_row_idxs
)
-
len
(
comm_cpu_row_idxs
))
self
.
num_miss_history
.
append
(
len
(
comm_cpu_row_idxs
))
self
.
num_write_back_history
.
append
(
0
)
# move sure the cuda
chunk
will not be evicted!
# move sure the cuda
rows
will not be evicted!
with
record_function
(
"(zhg) cache update"
):
self
.
_prepare_rows_on_cuda
(
comm_cpu_row_idxs
)
self
.
evict_backlist
=
torch
.
tensor
([],
device
=
cpu_row_idxs
.
device
,
dtype
=
cpu_row_idxs
.
dtype
)
# new ids chunk_offset + offset_in_chunk
with
record_function
(
"(zhg) embed idx -> cache
chunk
id"
):
with
record_function
(
"(zhg) embed
cpu rows
idx -> cache
gpu row
id
xs
"
):
gpu_row_idxs
=
self
.
_id_to_cached_cuda_id
(
ids
)
# update for LFU.
...
...
@@ -311,17 +310,17 @@ class CachedParamMgr(torch.nn.Module):
self
.
_cuda_to_cpu_elapse
=
0
self
.
_cuda_to_cpu_numel
=
0
def
_
chunk
_in_cuda
(
self
,
chunk
_id
:
int
)
->
bool
:
return
self
.
inverted_cached_idx
[
chunk
_id
]
!=
-
1
def
_
row
_in_cuda
(
self
,
row
_id
:
int
)
->
bool
:
return
self
.
inverted_cached_idx
[
row
_id
]
!=
-
1
@
torch
.
no_grad
()
def
_prepare_rows_on_cuda
(
self
,
cpu_row_idxs
:
torch
.
Tensor
)
->
None
:
"""prepare rows in cpu_row_idxs on CUDA memory
Args:
cpu_row_idxs (torch.Tensor): the
chunk
s to be placed on CUDA
cpu_row_idxs (torch.Tensor): the
row
s to be placed on CUDA
"""
evict_num
=
cpu_row_idxs
.
numel
()
-
self
.
cuda_available_
chunk
_num
evict_num
=
cpu_row_idxs
.
numel
()
-
self
.
cuda_available_
row
_num
if
evict_num
>
0
:
with
Timer
()
as
timer
:
mask_cpu_row_idx
=
torch
.
isin
(
self
.
cached_idx_map
,
self
.
evict_backlist
)
...
...
@@ -396,7 +395,7 @@ class CachedParamMgr(torch.nn.Module):
"""
deprecated
evict one
chunk
from cuda to cpu.
evict one
row
from cuda to cpu.
Returns:
(int) : the slot id be evicted.
"""
...
...
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
View file @
ba61109b
...
...
@@ -119,8 +119,4 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
if
self
.
cache_weight_mgr
.
_cuda_to_cpu_numel
>
0
:
return
self
.
cache_weight_mgr
.
_cuda_to_cpu_numel
*
self
.
cache_weight_mgr
.
elem_size_in_byte
/
1e6
/
\
self
.
cache_weight_mgr
.
_cuda_to_cpu_elapse
return
0
@
property
def
input_id_percent_in_load_chunk
(
self
):
return
0
# np.mean(self.cache_weight_mgr.input_id_percent_in_load_chunk) * 100
return
0
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment