Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9feee6d0
Unverified
Commit
9feee6d0
authored
Aug 29, 2022
by
Jiarui Fang
Committed by
GitHub
Aug 29, 2022
Browse files
[FAW] LFU initialize with dataset freq (#1513)
parent
1b8fee8e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
23 deletions
+32
-23
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+32
-23
No files found.
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
View file @
9feee6d0
...
@@ -14,7 +14,6 @@ class EvictionStrategy(Enum):
...
@@ -14,7 +14,6 @@ class EvictionStrategy(Enum):
DATASET
=
2
DATASET
=
2
class
CachedParamMgr
(
torch
.
nn
.
Module
):
class
CachedParamMgr
(
torch
.
nn
.
Module
):
"""
"""
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
...
@@ -64,8 +63,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -64,8 +63,7 @@ class CachedParamMgr(torch.nn.Module):
# cache_row_idx -> frequency, freq of the cache rows.
# cache_row_idx -> frequency, freq of the cache rows.
# classic lfu cache. evict the minimal freq value row in cuda cache.
# classic lfu cache. evict the minimal freq value row in cuda cache.
self
.
register_buffer
(
"freq_cnter"
,
self
.
register_buffer
(
"freq_cnter"
,
torch
.
empty
(
self
.
cuda_row_num
,
torch
.
empty
(
self
.
cuda_row_num
,
device
=
torch
.
cuda
.
current_device
(),
device
=
torch
.
cuda
.
current_device
(),
dtype
=
torch
.
long
).
fill_
(
sys
.
maxsize
),
dtype
=
torch
.
long
).
fill_
(
sys
.
maxsize
),
persistent
=
False
)
persistent
=
False
)
...
@@ -82,14 +80,14 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -82,14 +80,14 @@ class CachedParamMgr(torch.nn.Module):
"""
"""
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
# find the minimal evict_num freq entries in cached_idx_map
# find the minimal evict_num freq entries in cached_idx_map
_
,
evict_gpu_row_idxs
=
torch
.
topk
(
self
.
freq_cnter
,
evict_num
,
largest
=
False
)
_
,
evict_gpu_row_idxs
=
torch
.
topk
(
self
.
freq_cnter
,
evict_num
,
largest
=
False
)
return
evict_gpu_row_idxs
return
evict_gpu_row_idxs
elif
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
elif
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
# cached_idx_map itself implies the priority of eviction.
# cached_idx_map itself implies the priority of eviction.
# The value of self.cached_idx_map represents cpu_row_idx.
# The value of self.cached_idx_map represents cpu_row_idx.
# The larger it is, the less frequently it will appear in the dataset,
# The larger it is, the less frequently it will appear in the dataset,
# and the higher its eviction priority will be.
# and the higher its eviction priority will be.
_
,
evict_gpu_row_idxs
=
torch
.
topk
(
self
.
cached_idx_map
,
evict_num
,
largest
=
True
)
_
,
evict_gpu_row_idxs
=
torch
.
topk
(
self
.
cached_idx_map
,
evict_num
,
largest
=
True
)
return
evict_gpu_row_idxs
return
evict_gpu_row_idxs
else
:
else
:
raise
TypeError
raise
TypeError
...
@@ -163,8 +161,12 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -163,8 +161,12 @@ class CachedParamMgr(torch.nn.Module):
reorder the weight according to ids' frequency in dataset before training.
reorder the weight according to ids' frequency in dataset before training.
Execute only once before training, also known as warmup phase.
Execute only once before training, also known as warmup phase.
:NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
Note:
:NOTE If you are use the LFU as the eviction strategy, you can skip this function.
If you would like to use the DATASET as the eviction strategy, you must call this function.
Note:
If you are use the LFU as the eviction strategy, you can skip this function. If you still use this function. It will initialize
The frequency in LFU cache using the dataset statistics.
Args:
Args:
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
...
@@ -182,24 +184,31 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -182,24 +184,31 @@ class CachedParamMgr(torch.nn.Module):
with
Timer
()
as
timer
:
with
Timer
()
as
timer
:
# extract rows from cpu weight
# extract rows from cpu weight
preload_row_ids
=
torch
.
arange
(
preload_row_num
)
preload_row_ids
=
torch
.
arange
(
preload_row_num
)
preload_
slot
_ids
=
preload_row_ids
.
cuda
()
preload_
cuda_row
_id
x
s
=
preload_row_ids
.
cuda
()
if
self
.
buffer_size
>
0
:
if
self
.
buffer_size
>
0
:
self
.
limit_buff_index_copyer
.
index_copy
(
0
,
self
.
limit_buff_index_copyer
.
index_copy
(
0
,
src_index
=
preload_row_ids
,
src_index
=
preload_row_ids
,
tgt_index
=
preload_
slot
_ids
,
tgt_index
=
preload_
cuda_row
_id
x
s
,
src
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
),
src
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
),
tgt
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
))
tgt
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
))
else
:
else
:
preload_rows
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_select
(
0
,
preload_row_ids
).
cuda
()
preload_rows
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_select
(
0
,
preload_row_ids
).
cuda
()
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_copy_
(
0
,
preload_slot_ids
,
preload_rows
)
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_copy_
(
0
,
preload_cuda_row_idxs
,
preload_rows
)
# update auxiliary info
# update auxiliary info
slot_offsets
=
preload_slot_ids
slot_offsets
=
preload_cuda_row_idxs
self
.
cached_idx_map
[
preload_slot_ids
]
=
preload_slot_ids
self
.
cached_idx_map
[
preload_cuda_row_idxs
]
=
preload_cuda_row_idxs
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
.
index_fill_
(
0
,
preload_slot_ids
,
0
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
inverted_cached_idx
[
preload_slot_ids
]
=
slot_offsets
# if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset.
if
ids_freq_mapping
is
None
:
self
.
freq_cnter
.
index_fill_
(
0
,
preload_cuda_row_idxs
,
0
)
else
:
self
.
freq_cnter
.
index_fill_
(
0
,
preload_cuda_row_idxs
,
self
.
idx_map
[
preload_cuda_row_idxs
])
self
.
inverted_cached_idx
[
preload_cuda_row_idxs
]
=
slot_offsets
self
.
_cuda_available_row_num
-=
preload_row_num
self
.
_cuda_available_row_num
-=
preload_row_num
print
(
f
'Cache warmup finished cost
{
timer
.
elapsed
}
sec.'
)
print
(
f
'Cache warmup finished cost
{
timer
.
elapsed
}
sec.'
)
...
@@ -215,7 +224,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -215,7 +224,7 @@ class CachedParamMgr(torch.nn.Module):
self
.
inverted_cached_idx
.
index_fill_
(
0
,
row_ids
,
-
1
)
self
.
inverted_cached_idx
.
index_fill_
(
0
,
row_ids
,
-
1
)
self
.
_cuda_available_row_num
+=
slots
.
numel
()
self
.
_cuda_available_row_num
+=
slots
.
numel
()
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
.
fill_
(
sys
.
maxsize
)
self
.
freq_cnter
.
fill_
(
sys
.
maxsize
)
assert
self
.
_cuda_available_row_num
==
self
.
cuda_row_num
assert
self
.
_cuda_available_row_num
==
self
.
cuda_row_num
assert
torch
.
all
(
self
.
inverted_cached_idx
==
-
1
).
item
()
assert
torch
.
all
(
self
.
inverted_cached_idx
==
-
1
).
item
()
...
@@ -258,7 +267,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -258,7 +267,7 @@ class CachedParamMgr(torch.nn.Module):
torch.Tensor: indices on the cuda_cached_weight.
torch.Tensor: indices on the cuda_cached_weight.
"""
"""
with
record_function
(
"(zhg) get unique indices"
):
with
record_function
(
"(zhg) get unique indices"
):
cpu_row_idxs
,
repeat_times
=
torch
.
unique
(
self
.
idx_map
.
index_select
(
0
,
ids
),
return_counts
=
True
)
cpu_row_idxs
,
repeat_times
=
torch
.
unique
(
self
.
idx_map
.
index_select
(
0
,
ids
),
return_counts
=
True
)
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
f
"You move
{
len
(
cpu_row_idxs
)
}
embedding rows from CPU to CUDA. "
\
f
"You move
{
len
(
cpu_row_idxs
)
}
embedding rows from CPU to CUDA. "
\
...
@@ -283,10 +292,10 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -283,10 +292,10 @@ class CachedParamMgr(torch.nn.Module):
gpu_row_idxs
=
self
.
_id_to_cached_cuda_id
(
ids
)
gpu_row_idxs
=
self
.
_id_to_cached_cuda_id
(
ids
)
# update for LFU.
# update for LFU.
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
unique_gpu_row_idxs
=
self
.
inverted_cached_idx
[
cpu_row_idxs
]
unique_gpu_row_idxs
=
self
.
inverted_cached_idx
[
cpu_row_idxs
]
self
.
freq_cnter
.
scatter_add_
(
0
,
unique_gpu_row_idxs
,
repeat_times
)
self
.
freq_cnter
.
scatter_add_
(
0
,
unique_gpu_row_idxs
,
repeat_times
)
return
gpu_row_idxs
return
gpu_row_idxs
def
_reset_comm_stats
(
self
):
def
_reset_comm_stats
(
self
):
...
@@ -363,7 +372,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -363,7 +372,7 @@ class CachedParamMgr(torch.nn.Module):
slot_offsets
=
slots
slot_offsets
=
slots
self
.
cached_idx_map
[
slots
]
=
cpu_row_idxs
self
.
cached_idx_map
[
slots
]
=
cpu_row_idxs
self
.
inverted_cached_idx
.
index_copy_
(
0
,
cpu_row_idxs
,
slot_offsets
)
self
.
inverted_cached_idx
.
index_copy_
(
0
,
cpu_row_idxs
,
slot_offsets
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
.
index_fill_
(
0
,
slots
,
0
)
self
.
freq_cnter
.
index_fill_
(
0
,
slots
,
0
)
self
.
_cuda_available_row_num
-=
cpu_row_idxs
.
numel
()
self
.
_cuda_available_row_num
-=
cpu_row_idxs
.
numel
()
self
.
_cpu_to_cuda_elpase
+=
timer
.
elapsed
self
.
_cpu_to_cuda_elpase
+=
timer
.
elapsed
...
@@ -407,7 +416,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -407,7 +416,7 @@ class CachedParamMgr(torch.nn.Module):
# update inverted_cached_idx, min_slot_id is evicted from cuda
# update inverted_cached_idx, min_slot_id is evicted from cuda
self
.
cached_idx_map
[
max_cpu_row_idx
]
=
-
1
self
.
cached_idx_map
[
max_cpu_row_idx
]
=
-
1
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
[
max_cpu_row_idx
]
=
sys
.
maxsize
self
.
freq_cnter
[
max_cpu_row_idx
]
=
sys
.
maxsize
self
.
inverted_cached_idx
[
max_gpu_row_idx
]
=
-
1
self
.
inverted_cached_idx
[
max_gpu_row_idx
]
=
-
1
...
@@ -443,7 +452,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -443,7 +452,7 @@ class CachedParamMgr(torch.nn.Module):
# update the inverted_cached_idx
# update the inverted_cached_idx
self
.
cached_idx_map
[
slot_id
]
=
row_id
self
.
cached_idx_map
[
slot_id
]
=
row_id
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
[
slot_id
]
=
0
self
.
freq_cnter
[
slot_id
]
=
0
self
.
inverted_cached_idx
[
row_id
]
=
slot_offset
self
.
inverted_cached_idx
[
row_id
]
=
slot_offset
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment