Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
af5438ca
Unverified
Commit
af5438ca
authored
Aug 29, 2022
by
Jiarui Fang
Committed by
GitHub
Aug 29, 2022
Browse files
[FAW] refactor reorder() for CachedParamMgr (#1514)
parent
9feee6d0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
51 deletions
+63
-51
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+24
-15
tests/test_layers/test_cache_embedding.py
tests/test_layers/test_cache_embedding.py
+39
-36
No files found.
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
View file @
af5438ca
...
...
@@ -172,44 +172,53 @@ class CachedParamMgr(torch.nn.Module):
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
warmup_ratio (float): the amount of chunks preloaded in cuda cache
"""
if
ids_freq_mapping
is
not
None
:
if
not
isinstance
(
ids_freq_mapping
,
torch
.
Tensor
):
ids_freq_mapping
=
torch
.
tensor
(
ids_freq_mapping
)
tmp_idx
=
torch
.
argsort
(
ids_freq_mapping
,
descending
=
True
)
sorted_idx
=
torch
.
argsort
(
tmp_idx
)
self
.
idx_map
.
data
.
copy_
(
sorted_idx
)
# reorder phase: reorder the cpu weight according to their freq stats in the target dataset.
# reorder only works for DATASET eviction strategy.
if
ids_freq_mapping
is
not
None
and
not
isinstance
(
ids_freq_mapping
,
torch
.
Tensor
):
ids_freq_mapping
=
torch
.
tensor
(
ids_freq_mapping
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
if
ids_freq_mapping
is
not
None
:
tmp_idx
=
torch
.
argsort
(
ids_freq_mapping
,
descending
=
True
)
sorted_idx
=
torch
.
argsort
(
tmp_idx
)
self
.
idx_map
.
data
.
copy_
(
sorted_idx
)
# warmup phase: copy #preload_row_num rows from cpu to gpu.
preload_row_num
=
min
(
int
(
np
.
ceil
(
self
.
cuda_row_num
*
warmup_ratio
)),
self
.
num_embeddings
)
if
preload_row_num
>
0
:
with
Timer
()
as
timer
:
# extract rows from cpu weight
preload_row_ids
=
torch
.
arange
(
preload_row_num
)
preload_cuda_row_idxs
=
preload_row_ids
.
cuda
()
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
and
ids_freq_mapping
is
not
None
:
freq_value
,
preload_cpu_ids
=
torch
.
topk
(
ids_freq_mapping
,
preload_row_num
,
dim
=
0
,
largest
=
True
)
preload_cuda_row_idxs
=
torch
.
arange
(
preload_row_num
).
cuda
()
else
:
preload_cpu_ids
=
torch
.
arange
(
preload_row_num
)
preload_cuda_row_idxs
=
preload_cpu_ids
.
cuda
()
if
self
.
buffer_size
>
0
:
self
.
limit_buff_index_copyer
.
index_copy
(
0
,
src_index
=
preload_
row
_ids
,
src_index
=
preload_
cpu
_ids
,
tgt_index
=
preload_cuda_row_idxs
,
src
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
),
tgt
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
))
else
:
preload_rows
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_select
(
0
,
preload_
row
_ids
).
cuda
()
preload_rows
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_select
(
0
,
preload_
cpu
_ids
).
cuda
()
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_copy_
(
0
,
preload_cuda_row_idxs
,
preload_rows
)
# update auxiliary info
slot_offsets
=
preload_cuda_row_idxs
self
.
cached_idx_map
[
preload_cuda_row_idxs
]
=
preload_cuda_row_idxs
self
.
cached_idx_map
[
preload_cuda_row_idxs
]
=
preload_cpu_ids
.
cuda
()
self
.
inverted_cached_idx
[
preload_cpu_ids
]
=
preload_cuda_row_idxs
self
.
_cuda_available_row_num
-=
preload_row_num
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
# if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset.
if
ids_freq_mapping
is
None
:
self
.
freq_cnter
.
index_fill_
(
0
,
preload_cuda_row_idxs
,
0
)
else
:
self
.
freq_cnter
.
index_fill_
(
0
,
preload_cuda_row_idxs
,
self
.
idx_map
[
preload_cuda_row_idxs
]
)
self
.
freq_cnter
[
preload_cuda_row_idxs
]
=
freq_value
.
cuda
(
)
self
.
inverted_cached_idx
[
preload_cuda_row_idxs
]
=
slot_offsets
self
.
_cuda_available_row_num
-=
preload_row_num
print
(
f
'Cache warmup finished cost
{
timer
.
elapsed
}
sec.'
)
def
flush
(
self
):
...
...
tests/test_layers/test_cache_embedding.py
View file @
af5438ca
...
...
@@ -144,49 +144,52 @@ def test_freq_aware_embed(use_LFU: bool):
assert
torch
.
allclose
(
model_weight
,
ref_weight
),
\
f
"model weight:
{
model_weight
[
10
:
18
,
:
8
]
}
, reference:
{
ref_weight
[
10
:
18
,
:
8
]
}
"
def
test_lfu_strategy
():
@
pytest
.
mark
.
parametrize
(
'init_freq'
,
[
True
,
False
])
def
test_lfu_strategy
(
init_freq
:
bool
):
# minimal test to check behavior
Bag
=
FreqAwareEmbeddingBag
(
5
,
5
,
cuda_row_num
=
3
,
buffer_size
=
0
,
pin_weight
=
Tru
e
,
warmup_ratio
=
0
.0
,
evict_strategy
=
EvictionStrategy
.
LFU
)
offsets
=
torch
.
tensor
([
0
],
device
=
"cuda:0"
)
Bag
=
FreqAwareEmbeddingBag
(
5
,
5
,
cuda_row_num
=
3
,
buffer_size
=
0
,
pin_weight
=
True
,
ids_freq_mapping
=
[
4
,
2
,
1
,
3
,
1
]
if
init_freq
else
Non
e
,
warmup_ratio
=
1
.0
,
evict_strategy
=
EvictionStrategy
.
LFU
)
# print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
offsets
=
torch
.
tensor
([
0
],
device
=
"cuda:0"
)
# prepare frequency learning info:
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
# check strategy
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
3
],
device
=
"cuda:0"
),
offsets
)
# miss, evict 1
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
# hit
Bag
.
forward
(
torch
.
tensor
([
4
],
device
=
"cuda:0"
),
offsets
)
# miss, evict 3
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
# hit
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
# hit
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
0
,
1
,
2
],
device
=
"cuda:0"
),
offsets
)
Bag
.
forward
(
torch
.
tensor
([
3
],
device
=
"cuda:0"
),
offsets
)
# miss, evict 1
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
# hit
Bag
.
forward
(
torch
.
tensor
([
4
],
device
=
"cuda:0"
),
offsets
)
# miss, evict 3
Bag
.
forward
(
torch
.
tensor
([
2
],
device
=
"cuda:0"
),
offsets
)
# hit
Bag
.
forward
(
torch
.
tensor
([
0
],
device
=
"cuda:0"
),
offsets
)
# hit
assert
torch
.
allclose
(
torch
.
Tensor
(
Bag
.
cache_weight_mgr
.
num_hits_history
[
-
6
:]),
torch
.
Tensor
([
3
,
0
,
1
,
0
,
1
,
1
])),
\
"LFU strategy behavior failed"
def
gather_tensor
(
tensor
,
rank
,
world_size
):
gather_list
=
[]
if
rank
==
0
:
...
...
@@ -279,4 +282,4 @@ def test_parallel_freq_aware_embed(world_size):
if
__name__
==
'__main__'
:
# test_freq_aware_embed(True)
# test_parallel_freq_aware_embed(2)
test_lfu_strategy
()
\ No newline at end of file
test_lfu_strategy
(
False
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment