Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0ed2f461
Unverified
Commit
0ed2f461
authored
Aug 26, 2022
by
CsRic
Committed by
GitHub
Aug 26, 2022
Browse files
[FAW] FAW embedding use LRU as eviction strategy initialized with dataset stats (#1494)
parent
8b7d6bd5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
25 deletions
+40
-25
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+31
-20
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
...l/layers/cache_embedding/parallel_freq_aware_embedding.py
+3
-2
tests/test_layers/test_cache_embedding.py
tests/test_layers/test_cache_embedding.py
+6
-3
No files found.
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
View file @
0ed2f461
...
@@ -5,7 +5,7 @@ from typing import List, Optional
...
@@ -5,7 +5,7 @@ from typing import List, Optional
from
contexttimer
import
Timer
from
contexttimer
import
Timer
from
.copyer
import
LimitBuffIndexCopyer
from
.copyer
import
LimitBuffIndexCopyer
from
enum
import
Enum
from
enum
import
Enum
import
sys
class
EvictionStrategy
(
Enum
):
class
EvictionStrategy
(
Enum
):
LFU
=
1
LFU
=
1
...
@@ -25,14 +25,14 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -25,14 +25,14 @@ class CachedParamMgr(torch.nn.Module):
cuda_row_num
:
int
=
0
,
cuda_row_num
:
int
=
0
,
buffer_size
:
int
=
50_000
,
buffer_size
:
int
=
50_000
,
pin_weight
=
False
,
pin_weight
=
False
,
evict_strategy
=
EvictionStrategy
.
DATASET
)
->
None
:
evict_strategy
=
EvictionStrategy
.
DATASET
,
)
->
None
:
super
(
CachedParamMgr
,
self
).
__init__
()
super
(
CachedParamMgr
,
self
).
__init__
()
self
.
buffer_size
=
buffer_size
self
.
buffer_size
=
buffer_size
self
.
num_embeddings
,
self
.
embedding_dim
=
weight
.
shape
self
.
num_embeddings
,
self
.
embedding_dim
=
weight
.
shape
self
.
cuda_row_num
=
cuda_row_num
self
.
cuda_row_num
=
cuda_row_num
self
.
_cuda_available_row_num
=
self
.
cuda_row_num
self
.
_cuda_available_row_num
=
self
.
cuda_row_num
self
.
pin_weight
=
pin_weight
self
.
pin_weight
=
pin_weight
self
.
elem_size_in_byte
=
weight
.
element_size
()
self
.
elem_size_in_byte
=
weight
.
element_size
()
# weight configure
# weight configure
...
@@ -50,12 +50,22 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -50,12 +50,22 @@ class CachedParamMgr(torch.nn.Module):
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
# cpu_row_idx -> frequency, freq of the cpu rows.
# cpu_row_idx -> frequency, freq of the cpu rows.
# evict the minimal freq value row in cuda cache.
# evict the minimal freq value row in cuda cache.
'''
during cache eviction, if a cached_idx_map element maps to a masked cpu_idx, we temporarily re-map that element to -1.
also, disabled cached_idx_map element maps to -1 by default.
freq_cnter[-1], the last element, should ALWAYS be MAX VALUE so those masked or disabled idxs will be argsorted to end,
not being chosen to evict.
ZH: freq_cnter的最后一位设为了最大值, 不该被选为换出的cache idx都是-1, 指向这个最大值, 所以排序时在队尾, 不会被选中换出
'''
self
.
register_buffer
(
"freq_cnter"
,
self
.
register_buffer
(
"freq_cnter"
,
torch
.
empty
(
self
.
num_embeddings
,
device
=
torch
.
cuda
.
current_device
(),
torch
.
empty
(
self
.
num_embeddings
+
1
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
torch
.
long
).
fill_
(
0
),
dtype
=
torch
.
long
).
fill_
(
0
),
persistent
=
False
)
persistent
=
False
)
self
.
freq_cnter
[
-
1
]
=
sys
.
maxsize
def
_update_freq_cnter
(
self
,
cpu_row_idxs
:
torch
.
Tensor
)
->
None
:
def
_update_freq_cnter
(
self
,
cpu_row_idxs
_original
:
torch
.
Tensor
)
->
None
:
"""_update_freq_cnter
"""_update_freq_cnter
Update the frequency value w.r.t. the cpu_row_ids in self.freq_cnter.
Update the frequency value w.r.t. the cpu_row_ids in self.freq_cnter.
...
@@ -64,7 +74,8 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -64,7 +74,8 @@ class CachedParamMgr(torch.nn.Module):
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
"""
"""
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
[
cpu_row_idxs
]
+=
1
add_num
=
torch
.
bincount
(
cpu_row_idxs_original
)
self
.
freq_cnter
[:
add_num
.
shape
[
0
]]
+=
add_num
def
_find_evict_gpu_idxs
(
self
,
evict_num
:
int
)
->
torch
.
Tensor
:
def
_find_evict_gpu_idxs
(
self
,
evict_num
:
int
)
->
torch
.
Tensor
:
"""_find_evict_gpu_idxs
"""_find_evict_gpu_idxs
...
@@ -165,10 +176,13 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -165,10 +176,13 @@ class CachedParamMgr(torch.nn.Module):
warmup_ratio (float): the amount of chunks preloaded in cuda cache
warmup_ratio (float): the amount of chunks preloaded in cuda cache
"""
"""
if
ids_freq_mapping
is
not
None
:
if
ids_freq_mapping
is
not
None
:
ids_freq_mapping
=
torch
.
tensor
(
ids_freq_mapping
)
tmp_idx
=
torch
.
argsort
(
ids_freq_mapping
,
descending
=
True
)
tmp_idx
=
torch
.
argsort
(
ids_freq_mapping
,
descending
=
True
)
sorted_idx
=
torch
.
argsort
(
tmp_idx
)
sorted_idx
=
torch
.
argsort
(
tmp_idx
)
self
.
idx_map
.
data
.
copy_
(
sorted_idx
)
self
.
idx_map
.
data
.
copy_
(
sorted_idx
)
#initialize freq_cnter if use LFU
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
[:
-
1
],
_
=
torch
.
sort
(
ids_freq_mapping
)
# TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
# TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
# As cuda_cached_weight is very big. You may not have that much available memory!
# As cuda_cached_weight is very big. You may not have that much available memory!
# Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
# Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
...
@@ -249,8 +263,9 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -249,8 +263,9 @@ class CachedParamMgr(torch.nn.Module):
torch.Tensor: indices on the cuda_cached_weight.
torch.Tensor: indices on the cuda_cached_weight.
"""
"""
with
record_function
(
"(zhg) get unique indices"
):
with
record_function
(
"(zhg) get unique indices"
):
cpu_row_idxs
=
torch
.
unique
(
self
.
idx_map
.
index_select
(
0
,
ids
))
cpu_row_idxs_original
=
self
.
idx_map
.
index_select
(
0
,
ids
)
cpu_row_idxs
=
torch
.
unique
(
cpu_row_idxs_original
)
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
f
"the input indices pull
{
len
(
cpu_row_idxs
)
}
chunks, "
\
f
"the input indices pull
{
len
(
cpu_row_idxs
)
}
chunks, "
\
f
"which is larger than the presented
{
self
.
cuda_row_num
}
, "
\
f
"which is larger than the presented
{
self
.
cuda_row_num
}
, "
\
...
@@ -272,10 +287,9 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -272,10 +287,9 @@ class CachedParamMgr(torch.nn.Module):
# new ids chunk_offset + offset_in_chunk
# new ids chunk_offset + offset_in_chunk
with
record_function
(
"(zhg) embed idx -> cache chunk id"
):
with
record_function
(
"(zhg) embed idx -> cache chunk id"
):
gpu_row_idxs
=
self
.
_id_to_cached_cuda_id
(
ids
)
gpu_row_idxs
=
self
.
_id_to_cached_cuda_id
(
ids
)
# update for LFU.
# update for LFU.
self
.
_update_freq_cnter
(
cpu_row_idxs
)
self
.
_update_freq_cnter
(
cpu_row_idxs_original
)
return
gpu_row_idxs
return
gpu_row_idxs
def
_reset_comm_stats
(
self
):
def
_reset_comm_stats
(
self
):
...
@@ -298,26 +312,23 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -298,26 +312,23 @@ class CachedParamMgr(torch.nn.Module):
if
evict_num
>
0
:
if
evict_num
>
0
:
with
Timer
()
as
timer
:
with
Timer
()
as
timer
:
mask_cpu_row_idx
=
torch
.
isin
(
self
.
cached_idx_map
,
self
.
evict_backlist
)
mask_cpu_row_idx
=
torch
.
isin
(
self
.
cached_idx_map
,
self
.
evict_backlist
)
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
# mask method.
# mask method.
# set cached_idx_map[invalid_idxs] to -2.
# set cached_idx_map[invalid_idxs] to -2.
# so those idxs will be sorted to end, therefore not being chosen as victim
# so those idxs will be sorted to end, therefore not being chosen as victim
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
backup_idxs
=
self
.
cached_idx_map
[
mask_cpu_row_idx
].
clone
()
backup_idxs
=
self
.
cached_idx_map
[
mask_cpu_row_idx
].
clone
()
self
.
cached_idx_map
.
index_fill_
(
0
,
invalid_idxs
,
-
2
)
self
.
cached_idx_map
.
index_fill_
(
0
,
invalid_idxs
,
-
2
)
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
self
.
cached_idx_map
.
index_copy_
(
0
,
invalid_idxs
,
backup_idxs
)
self
.
cached_idx_map
.
index_copy_
(
0
,
invalid_idxs
,
backup_idxs
)
elif
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
elif
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
# another mask method.
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
# set freq_cnter[invalid_idxs] to max
backup_idxs
=
self
.
cached_idx_map
[
mask_cpu_row_idx
].
clone
()
# so those idxs will be sorted to end, therefore not being chosen as victim
self
.
cached_idx_map
.
index_fill_
(
0
,
invalid_idxs
,
-
1
)
backup_cnter
=
self
.
freq_cnter
[
invalid_idxs
].
clone
()
self
.
freq_cnter
.
index_fill_
(
0
,
invalid_idxs
,
torch
.
max
(
self
.
freq_cnter
)
+
1
)
# or can we use a confident max value?
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
self
.
freq_cnter
.
index_copy_
(
0
,
invalid_idxs
,
backup_
cnter
)
self
.
cached_idx_map
.
index_copy_
(
0
,
invalid_idxs
,
backup_
idxs
)
evict_info
=
self
.
cached_idx_map
[
evict_gpu_row_idxs
]
evict_info
=
self
.
cached_idx_map
[
evict_gpu_row_idxs
]
...
...
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
View file @
0ed2f461
...
@@ -6,7 +6,7 @@ from .freq_aware_embedding import FreqAwareEmbeddingBag
...
@@ -6,7 +6,7 @@ from .freq_aware_embedding import FreqAwareEmbeddingBag
from
colossalai.nn._ops._utils
import
dual_all_to_all
from
colossalai.nn._ops._utils
import
dual_all_to_all
from
colossalai.tensor
import
ColoParameter
,
ShardSpec
,
ComputePattern
,
ProcessGroup
,
ColoTensorSpec
,
ColoTensor
from
colossalai.tensor
import
ColoParameter
,
ShardSpec
,
ComputePattern
,
ProcessGroup
,
ColoTensorSpec
,
ColoTensor
from
.cache_mgr
import
CachedParamMgr
,
EvictionStrategy
def
get_partition
(
embedding_dim
,
rank
,
world_size
)
->
Tuple
[
int
,
int
,
bool
]:
def
get_partition
(
embedding_dim
,
rank
,
world_size
)
->
Tuple
[
int
,
int
,
bool
]:
if
world_size
==
1
:
if
world_size
==
1
:
...
@@ -48,6 +48,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
...
@@ -48,6 +48,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
warmup_ratio
=
0.7
,
warmup_ratio
=
0.7
,
buffer_size
=
50_000
,
buffer_size
=
50_000
,
pin_weight
=
False
,
pin_weight
=
False
,
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
DATASET
):
):
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
world_size
=
torch
.
distributed
.
get_world_size
()
self
.
world_size
=
torch
.
distributed
.
get_world_size
()
...
@@ -59,7 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
...
@@ -59,7 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
super
(
ParallelFreqAwareEmbeddingBag
,
super
(
ParallelFreqAwareEmbeddingBag
,
self
).
__init__
(
num_embeddings
,
embedding_dim
,
padding_idx
,
max_norm
,
norm_type
,
scale_grad_by_freq
,
self
).
__init__
(
num_embeddings
,
embedding_dim
,
padding_idx
,
max_norm
,
norm_type
,
scale_grad_by_freq
,
sparse
,
_weight
,
mode
,
include_last_offset
,
dtype
,
device
,
cuda_row_num
,
ids_freq_mapping
,
sparse
,
_weight
,
mode
,
include_last_offset
,
dtype
,
device
,
cuda_row_num
,
ids_freq_mapping
,
warmup_ratio
,
buffer_size
,
pin_weight
)
warmup_ratio
,
buffer_size
,
pin_weight
,
evict_strategy
)
def
_weight_alloc
(
self
,
dtype
,
device
):
def
_weight_alloc
(
self
,
dtype
,
device
):
weight
=
torch
.
empty
(
self
.
num_embeddings
,
self
.
embedding_dim_per_partition
,
device
=
device
,
dtype
=
dtype
)
weight
=
torch
.
empty
(
self
.
num_embeddings
,
self
.
embedding_dim_per_partition
,
device
=
device
,
dtype
=
dtype
)
...
...
tests/test_layers/test_cache_embedding.py
View file @
0ed2f461
...
@@ -159,6 +159,9 @@ def test_lfu_strategy():
...
@@ -159,6 +159,9 @@ def test_lfu_strategy():
offsets
=
torch
.
tensor
([
0
],
device
=
"cuda:0"
)
offsets
=
torch
.
tensor
([
0
],
device
=
"cuda:0"
)
# prepare frequency learning info:
# prepare frequency learning info:
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
...
@@ -182,7 +185,7 @@ def test_lfu_strategy():
...
@@ -182,7 +185,7 @@ def test_lfu_strategy():
assert
torch
.
allclose
(
torch
.
Tensor
(
Bag
.
cache_weight_mgr
.
num_hits_history
[
-
6
:]),
torch
.
Tensor
([
3
,
0
,
1
,
0
,
1
,
1
])),
\
assert
torch
.
allclose
(
torch
.
Tensor
(
Bag
.
cache_weight_mgr
.
num_hits_history
[
-
6
:]),
torch
.
Tensor
([
3
,
0
,
1
,
0
,
1
,
1
])),
\
"LFU strategy behavior failed"
"LFU strategy behavior failed"
def
gather_tensor
(
tensor
,
rank
,
world_size
):
def
gather_tensor
(
tensor
,
rank
,
world_size
):
gather_list
=
[]
gather_list
=
[]
if
rank
==
0
:
if
rank
==
0
:
...
@@ -273,6 +276,6 @@ def test_parallel_freq_aware_embed(world_size):
...
@@ -273,6 +276,6 @@ def test_parallel_freq_aware_embed(world_size):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_freq_aware_embed
(
True
)
#
test_freq_aware_embed(True)
# test_parallel_freq_aware_embed(2)
# test_parallel_freq_aware_embed(2)
# test_lfu_strategy()
test_lfu_strategy
()
\ No newline at end of file
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment