Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1b8fee8e
Unverified
Commit
1b8fee8e
authored
Aug 29, 2022
by
CsRic
Committed by
GitHub
Aug 29, 2022
Browse files
[FAW] shrink freq_cnter size (#1509)
parent
f8945eef
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
41 deletions
+31
-41
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+29
-40
tests/test_layers/test_cache_embedding.py
tests/test_layers/test_cache_embedding.py
+2
-1
No files found.
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
View file @
1b8fee8e
...
@@ -14,6 +14,7 @@ class EvictionStrategy(Enum):
...
@@ -14,6 +14,7 @@ class EvictionStrategy(Enum):
DATASET
=
2
DATASET
=
2
class
CachedParamMgr
(
torch
.
nn
.
Module
):
class
CachedParamMgr
(
torch
.
nn
.
Module
):
"""
"""
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
...
@@ -46,7 +47,6 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -46,7 +47,6 @@ class CachedParamMgr(torch.nn.Module):
self
.
cuda_row_num
=
cuda_row_num
self
.
cuda_row_num
=
cuda_row_num
self
.
_cuda_available_row_num
=
self
.
cuda_row_num
self
.
_cuda_available_row_num
=
self
.
cuda_row_num
self
.
pin_weight
=
pin_weight
self
.
pin_weight
=
pin_weight
self
.
elem_size_in_byte
=
weight
.
element_size
()
self
.
elem_size_in_byte
=
weight
.
element_size
()
# weight configure
# weight configure
...
@@ -61,31 +61,13 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -61,31 +61,13 @@ class CachedParamMgr(torch.nn.Module):
self
.
_evict_strategy
=
evict_strategy
self
.
_evict_strategy
=
evict_strategy
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
# cpu_row_idx -> frequency, freq of the cpu rows.
# cache_row_idx -> frequency, freq of the cache rows.
# evict the minimal freq value row in cuda cache.
# classic lfu cache. evict the minimal freq value row in cuda cache.
'''
The last element of `freq_cnter` is set to the maximum value of int.
Rows of `self.cuda_weight` that store nothing (unused rows) have the value -1 in `self.cached_idx_map`.
In this way, the unused rows are placed at the end of the sorted order.
'''
self
.
register_buffer
(
"freq_cnter"
,
self
.
register_buffer
(
"freq_cnter"
,
torch
.
empty
(
self
.
num_embeddings
+
1
,
torch
.
empty
(
self
.
cuda_row_num
,
device
=
torch
.
cuda
.
current_device
(),
device
=
torch
.
cuda
.
current_device
(),
dtype
=
torch
.
long
).
fill_
(
0
),
dtype
=
torch
.
long
).
fill_
(
sys
.
maxsize
),
persistent
=
False
)
persistent
=
False
)
self
.
freq_cnter
[
-
1
]
=
sys
.
maxsize
def
_update_freq_cnter
(
self
,
cpu_row_idxs_original
:
torch
.
Tensor
)
->
None
:
"""_update_freq_cnter
Update the frequency value w.r.t. the cpu_row_ids in self.freq_cnter.
Args:
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
"""
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
add_num
=
torch
.
bincount
(
cpu_row_idxs_original
)
self
.
freq_cnter
[:
add_num
.
shape
[
0
]]
+=
add_num
def
_find_evict_gpu_idxs
(
self
,
evict_num
:
int
)
->
torch
.
Tensor
:
def
_find_evict_gpu_idxs
(
self
,
evict_num
:
int
)
->
torch
.
Tensor
:
"""_find_evict_gpu_idxs
"""_find_evict_gpu_idxs
...
@@ -100,14 +82,15 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -100,14 +82,15 @@ class CachedParamMgr(torch.nn.Module):
"""
"""
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
# find the minimal evict_num freq entries in cached_idx_map
# find the minimal evict_num freq entries in cached_idx_map
evict_gpu_row_idxs
=
torch
.
argsort
(
self
.
freq_cnter
[
self
.
cached_idx_map
])[:
evict_num
]
_
,
evict_gpu_row_idxs
=
torch
.
topk
(
self
.
freq_cnter
,
evict_num
,
largest
=
False
)
return
evict_gpu_row_idxs
return
evict_gpu_row_idxs
elif
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
elif
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
# cached_idx_map itself implies the priority of eviction.
# cached_idx_map itself implies the priority of eviction.
# The value of self.cached_idx_map represents cpu_row_idx.
# The value of self.cached_idx_map represents cpu_row_idx.
# The larger it is, the less frequently it will appear in the dataset,
# The larger it is, the less frequently it will appear in the dataset,
# and the higher its eviction priority will be.
# and the higher its eviction priority will be.
return
torch
.
argsort
(
self
.
cached_idx_map
,
descending
=
True
)[:
evict_num
]
_
,
evict_gpu_row_idxs
=
torch
.
topk
(
self
.
cached_idx_map
,
evict_num
,
largest
=
True
)
return
evict_gpu_row_idxs
else
:
else
:
raise
TypeError
raise
TypeError
...
@@ -181,8 +164,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -181,8 +164,7 @@ class CachedParamMgr(torch.nn.Module):
Execute only once before training, also known as warmup phase.
Execute only once before training, also known as warmup phase.
:NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
:NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
:NOTE If you use LFU as the eviction strategy, you can skip this function. The `freq_cnter` will be initialized as all zeros.
:NOTE If you use LFU as the eviction strategy, you can skip this function.
You can also call this function to initialize the `freq_cnter` with dataset frequency statistics.
Args:
Args:
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
...
@@ -194,9 +176,6 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -194,9 +176,6 @@ class CachedParamMgr(torch.nn.Module):
tmp_idx
=
torch
.
argsort
(
ids_freq_mapping
,
descending
=
True
)
tmp_idx
=
torch
.
argsort
(
ids_freq_mapping
,
descending
=
True
)
sorted_idx
=
torch
.
argsort
(
tmp_idx
)
sorted_idx
=
torch
.
argsort
(
tmp_idx
)
self
.
idx_map
.
data
.
copy_
(
sorted_idx
)
self
.
idx_map
.
data
.
copy_
(
sorted_idx
)
#initialize freq_cnter if use LFU
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
[:
-
1
],
_
=
torch
.
sort
(
ids_freq_mapping
)
preload_row_num
=
min
(
int
(
np
.
ceil
(
self
.
cuda_row_num
*
warmup_ratio
)),
self
.
num_embeddings
)
preload_row_num
=
min
(
int
(
np
.
ceil
(
self
.
cuda_row_num
*
warmup_ratio
)),
self
.
num_embeddings
)
if
preload_row_num
>
0
:
if
preload_row_num
>
0
:
...
@@ -218,6 +197,8 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -218,6 +197,8 @@ class CachedParamMgr(torch.nn.Module):
# update auxiliary info
# update auxiliary info
slot_offsets
=
preload_slot_ids
slot_offsets
=
preload_slot_ids
self
.
cached_idx_map
[
preload_slot_ids
]
=
preload_slot_ids
self
.
cached_idx_map
[
preload_slot_ids
]
=
preload_slot_ids
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
.
index_fill_
(
0
,
preload_slot_ids
,
0
)
self
.
inverted_cached_idx
[
preload_slot_ids
]
=
slot_offsets
self
.
inverted_cached_idx
[
preload_slot_ids
]
=
slot_offsets
self
.
_cuda_available_row_num
-=
preload_row_num
self
.
_cuda_available_row_num
-=
preload_row_num
print
(
f
'Cache warmup finished cost
{
timer
.
elapsed
}
sec.'
)
print
(
f
'Cache warmup finished cost
{
timer
.
elapsed
}
sec.'
)
...
@@ -234,6 +215,8 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -234,6 +215,8 @@ class CachedParamMgr(torch.nn.Module):
self
.
inverted_cached_idx
.
index_fill_
(
0
,
row_ids
,
-
1
)
self
.
inverted_cached_idx
.
index_fill_
(
0
,
row_ids
,
-
1
)
self
.
_cuda_available_row_num
+=
slots
.
numel
()
self
.
_cuda_available_row_num
+=
slots
.
numel
()
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
.
fill_
(
sys
.
maxsize
)
assert
self
.
_cuda_available_row_num
==
self
.
cuda_row_num
assert
self
.
_cuda_available_row_num
==
self
.
cuda_row_num
assert
torch
.
all
(
self
.
inverted_cached_idx
==
-
1
).
item
()
assert
torch
.
all
(
self
.
inverted_cached_idx
==
-
1
).
item
()
assert
torch
.
all
(
self
.
cached_idx_map
==
-
1
).
item
()
assert
torch
.
all
(
self
.
cached_idx_map
==
-
1
).
item
()
...
@@ -275,8 +258,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -275,8 +258,7 @@ class CachedParamMgr(torch.nn.Module):
torch.Tensor: indices on the cuda_cached_weight.
torch.Tensor: indices on the cuda_cached_weight.
"""
"""
with
record_function
(
"(zhg) get unique indices"
):
with
record_function
(
"(zhg) get unique indices"
):
cpu_row_idxs_original
=
self
.
idx_map
.
index_select
(
0
,
ids
)
cpu_row_idxs
,
repeat_times
=
torch
.
unique
(
self
.
idx_map
.
index_select
(
0
,
ids
),
return_counts
=
True
)
cpu_row_idxs
=
torch
.
unique
(
cpu_row_idxs_original
)
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
f
"You move
{
len
(
cpu_row_idxs
)
}
embedding rows from CPU to CUDA. "
\
f
"You move
{
len
(
cpu_row_idxs
)
}
embedding rows from CPU to CUDA. "
\
...
@@ -301,7 +283,10 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -301,7 +283,10 @@ class CachedParamMgr(torch.nn.Module):
gpu_row_idxs
=
self
.
_id_to_cached_cuda_id
(
ids
)
gpu_row_idxs
=
self
.
_id_to_cached_cuda_id
(
ids
)
# update for LFU.
# update for LFU.
self
.
_update_freq_cnter
(
cpu_row_idxs_original
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
unique_gpu_row_idxs
=
self
.
inverted_cached_idx
[
cpu_row_idxs
]
self
.
freq_cnter
.
scatter_add_
(
0
,
unique_gpu_row_idxs
,
repeat_times
)
return
gpu_row_idxs
return
gpu_row_idxs
def
_reset_comm_stats
(
self
):
def
_reset_comm_stats
(
self
):
...
@@ -324,23 +309,21 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -324,23 +309,21 @@ class CachedParamMgr(torch.nn.Module):
if
evict_num
>
0
:
if
evict_num
>
0
:
with
Timer
()
as
timer
:
with
Timer
()
as
timer
:
mask_cpu_row_idx
=
torch
.
isin
(
self
.
cached_idx_map
,
self
.
evict_backlist
)
mask_cpu_row_idx
=
torch
.
isin
(
self
.
cached_idx_map
,
self
.
evict_backlist
)
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
# mask method.
# mask method.
# set cached_idx_map[invalid_idxs] to -2.
# set cached_idx_map[invalid_idxs] to -2.
# so those idxs will be sorted to end, therefore not being chosen as victim
# so those idxs will be sorted to end, therefore not being chosen as victim
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
backup_idxs
=
self
.
cached_idx_map
[
mask_cpu_row_idx
].
clone
()
backup_idxs
=
self
.
cached_idx_map
[
mask_cpu_row_idx
].
clone
()
self
.
cached_idx_map
.
index_fill_
(
0
,
invalid_idxs
,
-
2
)
self
.
cached_idx_map
.
index_fill_
(
0
,
invalid_idxs
,
-
2
)
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
self
.
cached_idx_map
.
index_copy_
(
0
,
invalid_idxs
,
backup_idxs
)
self
.
cached_idx_map
.
index_copy_
(
0
,
invalid_idxs
,
backup_idxs
)
elif
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
elif
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
backup_freqs
=
self
.
freq_cnter
[
invalid_idxs
].
clone
()
backup_idxs
=
self
.
cached_idx_map
[
mask_cpu_row_idx
].
clone
()
self
.
freq_cnter
.
index_fill_
(
0
,
invalid_idxs
,
sys
.
maxsize
)
self
.
cached_idx_map
.
index_fill_
(
0
,
invalid_idxs
,
-
1
)
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
self
.
cached_idx_map
.
index_copy_
(
0
,
invalid_idxs
,
backup_
idx
s
)
self
.
freq_cnter
.
index_copy_
(
0
,
invalid_idxs
,
backup_
freq
s
)
evict_info
=
self
.
cached_idx_map
[
evict_gpu_row_idxs
]
evict_info
=
self
.
cached_idx_map
[
evict_gpu_row_idxs
]
...
@@ -357,6 +340,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -357,6 +340,7 @@ class CachedParamMgr(torch.nn.Module):
self
.
cached_idx_map
.
index_fill_
(
0
,
evict_gpu_row_idxs
,
-
1
)
self
.
cached_idx_map
.
index_fill_
(
0
,
evict_gpu_row_idxs
,
-
1
)
self
.
inverted_cached_idx
.
index_fill_
(
0
,
evict_info
,
-
1
)
self
.
inverted_cached_idx
.
index_fill_
(
0
,
evict_info
,
-
1
)
# self.freq_cnter.index_fill(0, evict_gpu_row_idxs, sys.maxsize) # unnecessary
self
.
_cuda_available_row_num
+=
evict_num
self
.
_cuda_available_row_num
+=
evict_num
weight_size
=
evict_gpu_row_idxs
.
numel
()
*
self
.
embedding_dim
weight_size
=
evict_gpu_row_idxs
.
numel
()
*
self
.
embedding_dim
...
@@ -379,6 +363,8 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -379,6 +363,8 @@ class CachedParamMgr(torch.nn.Module):
slot_offsets
=
slots
slot_offsets
=
slots
self
.
cached_idx_map
[
slots
]
=
cpu_row_idxs
self
.
cached_idx_map
[
slots
]
=
cpu_row_idxs
self
.
inverted_cached_idx
.
index_copy_
(
0
,
cpu_row_idxs
,
slot_offsets
)
self
.
inverted_cached_idx
.
index_copy_
(
0
,
cpu_row_idxs
,
slot_offsets
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
.
index_fill_
(
0
,
slots
,
0
)
self
.
_cuda_available_row_num
-=
cpu_row_idxs
.
numel
()
self
.
_cuda_available_row_num
-=
cpu_row_idxs
.
numel
()
self
.
_cpu_to_cuda_elpase
+=
timer
.
elapsed
self
.
_cpu_to_cuda_elpase
+=
timer
.
elapsed
weight_size
=
cpu_row_idxs
.
numel
()
*
self
.
embedding_dim
weight_size
=
cpu_row_idxs
.
numel
()
*
self
.
embedding_dim
...
@@ -421,7 +407,8 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -421,7 +407,8 @@ class CachedParamMgr(torch.nn.Module):
# update inverted_cached_idx, min_slot_id is evicted from cuda
# update inverted_cached_idx, min_slot_id is evicted from cuda
self
.
cached_idx_map
[
max_cpu_row_idx
]
=
-
1
self
.
cached_idx_map
[
max_cpu_row_idx
]
=
-
1
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
[
max_cpu_row_idx
]
=
sys
.
maxsize
self
.
inverted_cached_idx
[
max_gpu_row_idx
]
=
-
1
self
.
inverted_cached_idx
[
max_gpu_row_idx
]
=
-
1
self
.
_cuda_available_row_num
+=
1
self
.
_cuda_available_row_num
+=
1
...
@@ -456,6 +443,8 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -456,6 +443,8 @@ class CachedParamMgr(torch.nn.Module):
# update the inverted_cached_idx
# update the inverted_cached_idx
self
.
cached_idx_map
[
slot_id
]
=
row_id
self
.
cached_idx_map
[
slot_id
]
=
row_id
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
[
slot_id
]
=
0
self
.
inverted_cached_idx
[
row_id
]
=
slot_offset
self
.
inverted_cached_idx
[
row_id
]
=
slot_offset
self
.
_cuda_available_row_num
-=
1
self
.
_cuda_available_row_num
-=
1
...
...
tests/test_layers/test_cache_embedding.py
View file @
1b8fee8e
...
@@ -177,9 +177,10 @@ def test_lfu_strategy():
...
@@ -177,9 +177,10 @@ def test_lfu_strategy():
# check strategy
# check strategy
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
3
],
device
=
"cuda:0"
),
offsets
)
# miss, evict 1
Bag
.
forward
(
torch
.
tensor
([
3
],
device
=
"cuda:0"
),
offsets
)
# miss, evict 1
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
# hit
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
# hit
Bag
.
forward
(
torch
.
tensor
([
4
],
device
=
"cuda:0"
),
offsets
)
# miss, evict
1
Bag
.
forward
(
torch
.
tensor
([
4
],
device
=
"cuda:0"
),
offsets
)
# miss, evict
3
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
# hit
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
# hit
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
# hit
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
# hit
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment