Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0ed2f461
Unverified
Commit
0ed2f461
authored
Aug 26, 2022
by
CsRic
Committed by
GitHub
Aug 26, 2022
Browse files
[FAW] FAW embedding use LRU as eviction strategy initialized with dataset stats (#1494)
parent
8b7d6bd5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
25 deletions
+40
-25
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+31
-20
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
...l/layers/cache_embedding/parallel_freq_aware_embedding.py
+3
-2
tests/test_layers/test_cache_embedding.py
tests/test_layers/test_cache_embedding.py
+6
-3
No files found.
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
View file @
0ed2f461
...
...
@@ -5,7 +5,7 @@ from typing import List, Optional
from
contexttimer
import
Timer
from
.copyer
import
LimitBuffIndexCopyer
from
enum
import
Enum
import
sys
class
EvictionStrategy
(
Enum
):
LFU
=
1
...
...
@@ -25,14 +25,14 @@ class CachedParamMgr(torch.nn.Module):
cuda_row_num
:
int
=
0
,
buffer_size
:
int
=
50_000
,
pin_weight
=
False
,
evict_strategy
=
EvictionStrategy
.
DATASET
)
->
None
:
evict_strategy
=
EvictionStrategy
.
DATASET
,
)
->
None
:
super
(
CachedParamMgr
,
self
).
__init__
()
self
.
buffer_size
=
buffer_size
self
.
num_embeddings
,
self
.
embedding_dim
=
weight
.
shape
self
.
cuda_row_num
=
cuda_row_num
self
.
_cuda_available_row_num
=
self
.
cuda_row_num
self
.
pin_weight
=
pin_weight
self
.
elem_size_in_byte
=
weight
.
element_size
()
# weight configure
...
...
@@ -50,12 +50,22 @@ class CachedParamMgr(torch.nn.Module):
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
# cpu_row_idx -> frequency, freq of the cpu rows.
# evict the minimal freq value row in cuda cache.
'''
During cache eviction, if a cached_idx_map element maps to a masked cpu_idx, we temporarily re-map that element to -1.
Also, disabled cached_idx_map elements map to -1 by default.
freq_cnter[-1], the last element, should ALWAYS hold the MAX VALUE so that those masked or disabled idxs are argsorted to the end,
and are therefore never chosen for eviction.
ZH: freq_cnter的最后一位设为了最大值, 不该被选为换出的cache idx都是-1, 指向这个最大值, 所以排序时在队尾, 不会被选中换出
'''
self
.
register_buffer
(
"freq_cnter"
,
torch
.
empty
(
self
.
num_embeddings
,
device
=
torch
.
cuda
.
current_device
(),
torch
.
empty
(
self
.
num_embeddings
+
1
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
torch
.
long
).
fill_
(
0
),
persistent
=
False
)
self
.
freq_cnter
[
-
1
]
=
sys
.
maxsize
def
_update_freq_cnter
(
self
,
cpu_row_idxs
:
torch
.
Tensor
)
->
None
:
def
_update_freq_cnter
(
self
,
cpu_row_idxs
_original
:
torch
.
Tensor
)
->
None
:
"""_update_freq_cnter
Update the frequency value w.r.t. the cpu_row_ids in self.freq_cnter.
...
...
@@ -64,7 +74,8 @@ class CachedParamMgr(torch.nn.Module):
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
"""
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
[
cpu_row_idxs
]
+=
1
add_num
=
torch
.
bincount
(
cpu_row_idxs_original
)
self
.
freq_cnter
[:
add_num
.
shape
[
0
]]
+=
add_num
def
_find_evict_gpu_idxs
(
self
,
evict_num
:
int
)
->
torch
.
Tensor
:
"""_find_evict_gpu_idxs
...
...
@@ -165,10 +176,13 @@ class CachedParamMgr(torch.nn.Module):
warmup_ratio (float): the amount of chunks preloaded in cuda cache
"""
if
ids_freq_mapping
is
not
None
:
ids_freq_mapping
=
torch
.
tensor
(
ids_freq_mapping
)
tmp_idx
=
torch
.
argsort
(
ids_freq_mapping
,
descending
=
True
)
sorted_idx
=
torch
.
argsort
(
tmp_idx
)
self
.
idx_map
.
data
.
copy_
(
sorted_idx
)
# initialize freq_cnter if using LFU
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
[:
-
1
],
_
=
torch
.
sort
(
ids_freq_mapping
)
# TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
# As cuda_cached_weight is very big, you may not have that much available memory!
# Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
...
...
@@ -249,8 +263,9 @@ class CachedParamMgr(torch.nn.Module):
torch.Tensor: indices on the cuda_cached_weight.
"""
with
record_function
(
"(zhg) get unique indices"
):
cpu_row_idxs
=
torch
.
unique
(
self
.
idx_map
.
index_select
(
0
,
ids
))
cpu_row_idxs_original
=
self
.
idx_map
.
index_select
(
0
,
ids
)
cpu_row_idxs
=
torch
.
unique
(
cpu_row_idxs_original
)
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
f
"the input indices pull
{
len
(
cpu_row_idxs
)
}
chunks, "
\
f
"which is larger than the presented
{
self
.
cuda_row_num
}
, "
\
...
...
@@ -272,10 +287,9 @@ class CachedParamMgr(torch.nn.Module):
# new ids chunk_offset + offset_in_chunk
with
record_function
(
"(zhg) embed idx -> cache chunk id"
):
gpu_row_idxs
=
self
.
_id_to_cached_cuda_id
(
ids
)
# update for LFU.
self
.
_update_freq_cnter
(
cpu_row_idxs
)
self
.
_update_freq_cnter
(
cpu_row_idxs_original
)
return
gpu_row_idxs
def
_reset_comm_stats
(
self
):
...
...
@@ -298,26 +312,23 @@ class CachedParamMgr(torch.nn.Module):
if
evict_num
>
0
:
with
Timer
()
as
timer
:
mask_cpu_row_idx
=
torch
.
isin
(
self
.
cached_idx_map
,
self
.
evict_backlist
)
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
# mask method.
# set cached_idx_map[invalid_idxs] to -2.
# so those idxs will be sorted to end, therefore not being chosen as victim
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
backup_idxs
=
self
.
cached_idx_map
[
mask_cpu_row_idx
].
clone
()
self
.
cached_idx_map
.
index_fill_
(
0
,
invalid_idxs
,
-
2
)
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
self
.
cached_idx_map
.
index_copy_
(
0
,
invalid_idxs
,
backup_idxs
)
elif
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
# another mask method.
# set freq_cnter[invalid_idxs] to max
# so those idxs will be sorted to end, therefore not being chosen as victim
backup_cnter
=
self
.
freq_cnter
[
invalid_idxs
].
clone
()
self
.
freq_cnter
.
index_fill_
(
0
,
invalid_idxs
,
torch
.
max
(
self
.
freq_cnter
)
+
1
)
# or can we use a confident max value?
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
backup_idxs
=
self
.
cached_idx_map
[
mask_cpu_row_idx
].
clone
()
self
.
cached_idx_map
.
index_fill_
(
0
,
invalid_idxs
,
-
1
)
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
self
.
freq_cnter
.
index_copy_
(
0
,
invalid_idxs
,
backup_
cnter
)
self
.
cached_idx_map
.
index_copy_
(
0
,
invalid_idxs
,
backup_
idxs
)
evict_info
=
self
.
cached_idx_map
[
evict_gpu_row_idxs
]
...
...
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
View file @
0ed2f461
...
...
@@ -6,7 +6,7 @@ from .freq_aware_embedding import FreqAwareEmbeddingBag
from
colossalai.nn._ops._utils
import
dual_all_to_all
from
colossalai.tensor
import
ColoParameter
,
ShardSpec
,
ComputePattern
,
ProcessGroup
,
ColoTensorSpec
,
ColoTensor
from
.cache_mgr
import
CachedParamMgr
,
EvictionStrategy
def
get_partition
(
embedding_dim
,
rank
,
world_size
)
->
Tuple
[
int
,
int
,
bool
]:
if
world_size
==
1
:
...
...
@@ -48,6 +48,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
warmup_ratio
=
0.7
,
buffer_size
=
50_000
,
pin_weight
=
False
,
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
DATASET
):
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
world_size
=
torch
.
distributed
.
get_world_size
()
...
...
@@ -59,7 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
super
(
ParallelFreqAwareEmbeddingBag
,
self
).
__init__
(
num_embeddings
,
embedding_dim
,
padding_idx
,
max_norm
,
norm_type
,
scale_grad_by_freq
,
sparse
,
_weight
,
mode
,
include_last_offset
,
dtype
,
device
,
cuda_row_num
,
ids_freq_mapping
,
warmup_ratio
,
buffer_size
,
pin_weight
)
warmup_ratio
,
buffer_size
,
pin_weight
,
evict_strategy
)
def
_weight_alloc
(
self
,
dtype
,
device
):
weight
=
torch
.
empty
(
self
.
num_embeddings
,
self
.
embedding_dim_per_partition
,
device
=
device
,
dtype
=
dtype
)
...
...
tests/test_layers/test_cache_embedding.py
View file @
0ed2f461
...
...
@@ -159,6 +159,9 @@ def test_lfu_strategy():
offsets
=
torch
.
tensor
([
0
],
device
=
"cuda:0"
)
# prepare frequency learning info:
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
...
...
@@ -182,7 +185,7 @@ def test_lfu_strategy():
assert
torch
.
allclose
(
torch
.
Tensor
(
Bag
.
cache_weight_mgr
.
num_hits_history
[
-
6
:]),
torch
.
Tensor
([
3
,
0
,
1
,
0
,
1
,
1
])),
\
"LFU strategy behavior failed"
def
gather_tensor
(
tensor
,
rank
,
world_size
):
gather_list
=
[]
if
rank
==
0
:
...
...
@@ -273,6 +276,6 @@ def test_parallel_freq_aware_embed(world_size):
if
__name__
==
'__main__'
:
test_freq_aware_embed
(
True
)
#
test_freq_aware_embed(True)
# test_parallel_freq_aware_embed(2)
# test_lfu_strategy()
\ No newline at end of file
test_lfu_strategy
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment