OpenDAS / ColossalAI · Commits

Commit cde7b8a5 (Unverified)
Authored Aug 24, 2022 by Jiarui Fang; committed via GitHub on Aug 24, 2022
Parent: 32efe8e7

[FAW] init an LFU implementation for FAW (#1488)

Showing 5 changed files with 112 additions and 39 deletions (+112 −39)
colossalai/nn/parallel/layers/__init__.py (+2 −2)
colossalai/nn/parallel/layers/cache_embedding/__init__.py (+5 −2)
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py (+71 −8)
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py (+26 −22)
tests/test_layers/test_cache_embedding.py (+8 −5)
colossalai/nn/parallel/layers/__init__.py

@@ -3,10 +3,10 @@ from .linear import ColoLinear
 from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
-from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer
+from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy
 
 __all__ = [
     'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
     'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
-    'LimitBuffIndexCopyer'
+    'LimitBuffIndexCopyer', 'EvictionStrategy'
 ]
colossalai/nn/parallel/layers/cache_embedding/__init__.py

-from .cache_mgr import CachedParamMgr
+from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
 
-__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag']
+__all__ = [
+    'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
+    'EvictionStrategy'
+]
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py

@@ -4,6 +4,12 @@ from torch.profiler import record_function
 from typing import List, Optional
 from contexttimer import Timer
 from .copyer import LimitBuffIndexCopyer
+from enum import Enum
+
+
+class EvictionStrategy(Enum):
+    LFU = 1
+    DATASET = 2
 
 
 class CachedParamMgr(torch.nn.Module):

@@ -18,7 +24,8 @@ class CachedParamMgr(torch.nn.Module):
                  weight: torch.Tensor,
                  cuda_row_num: int = 0,
                  buffer_size: int = 50_000,
-                 pin_weight=False) -> None:
+                 pin_weight=False,
+                 evict_strategy=EvictionStrategy.DATASET) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
         self.num_embeddings, self.embedding_dim = weight.shape

@@ -38,6 +45,51 @@ class CachedParamMgr(torch.nn.Module):
         self.input_id_percent_in_load_chunk = []
         self._reset_comm_stats()
 
+        self._evict_strategy = evict_strategy
+
+        if self._evict_strategy == EvictionStrategy.LFU:
+            # cpu_row_idx -> frequency, freq of the cpu rows.
+            # evict the minimal freq value row in cuda cache.
+            self.register_buffer("freq_cnter",
+                                 torch.empty(self.num_embeddings,
+                                             device=torch.cuda.current_device(),
+                                             dtype=torch.long).fill_(0),
+                                 persistent=False)
+
+    def _update_freq_cnter(self, cpu_row_idxs: torch.Tensor) -> None:
+        """_update_freq_cnter
+
+        Update the frequency value w.r.t. the cpu_row_idxs in self.freq_cnter.
+
+        Args:
+            cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
+        """
+        if self._evict_strategy == EvictionStrategy.LFU:
+            self.freq_cnter[cpu_row_idxs] += 1
+
+    def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
+        """_find_evict_gpu_idxs
+
+        Find the gpu idxs to be evicted, according to their freq.
+
+        Args:
+            evict_num (int): how many rows have to be evicted.
+
+        Returns:
+            torch.Tensor: a 1D tensor containing the gpu_row_idxs.
+        """
+        if self._evict_strategy == EvictionStrategy.LFU:
+            # find the minimal evict_num freq entries in cached_idx_map
+            evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
+            return self.cached_idx_map[evict_gpu_row_idxs]
+        elif self._evict_strategy == EvictionStrategy.DATASET:
+            # cached_idx_map itself implies the priority of eviction.
+            # The value of self.cached_idx_map represents cpu_row_idx.
+            # The larger it is, the less frequently it appears in the dataset,
+            # and the higher its eviction priority is.
+            return torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
+        else:
+            raise TypeError
+
     def _init_weight(self, weight):
         if self.cuda_row_num > 0:
             # Enable cache with introducing auxiliary data structures

@@ -220,6 +272,10 @@ class CachedParamMgr(torch.nn.Module):
             # new ids chunk_offset + offset_in_chunk
         with record_function("(zhg) embed idx -> cache chunk id"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)
+
+        # update for LFU.
+        self._update_freq_cnter(cpu_row_idxs)
+
         return gpu_row_idxs
 
     def _reset_comm_stats(self):

@@ -234,6 +290,7 @@ class CachedParamMgr(torch.nn.Module):
     @torch.no_grad()
     def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
         """prepare rows in cpu_row_idxs on CUDA memory
+
         Args:
             cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA
         """

@@ -245,7 +302,9 @@ class CachedParamMgr(torch.nn.Module):
             invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
             self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
-            evict_gpu_row_idxs = torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
+            evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+
             self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
 
             evict_info = self.cached_idx_map[evict_gpu_row_idxs]

@@ -291,8 +350,16 @@ class CachedParamMgr(torch.nn.Module):
             self._cpu_to_cuda_numel += weight_size
             # print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
 
+    def _find_free_cuda_row(self) -> int:
+        if self._cuda_available_row_num == 0:
+            return -1
+        candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
+        return candidates[0].item()
+
     def _evict(self) -> int:
         """
         deprecated
         evict one chunk from cuda to cpu.
         Returns:
         (int) : the slot id be evicted.

@@ -329,15 +396,11 @@ class CachedParamMgr(torch.nn.Module):
             # self.num_write_back_history[-1] += 1
         return max_cpu_row_idx
 
-    def _find_free_cuda_row(self) -> int:
-        if self._cuda_available_row_num == 0:
-            return -1
-        candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
-        return candidates[0].item()
-
     @torch.no_grad()
     def _admit(self, row_id: int):
         """
         deprecated
         move in row_id to CUDA
         Args:
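In short, the LFU path keeps a per-row hit counter (freq_cnter) and, when rows must be evicted from the CUDA cache, picks the cache slots whose rows have the lowest counts. Below is a minimal standalone sketch of that bookkeeping with toy sizes and made-up row ids; it illustrates the idea and is not the CachedParamMgr code itself.

import torch

# Toy illustration of LFU eviction bookkeeping (not the actual CachedParamMgr code):
# a per-row hit counter, a map from GPU cache slots to row ids, and an argsort
# over the counters of the cached rows to pick the least-used slots.
num_embeddings, evict_num = 10, 2

freq_cnter = torch.zeros(num_embeddings, dtype=torch.long)    # hits per embedding row
cached_idx_map = torch.tensor([3, 7, 1, 9])                   # GPU cache slot -> row id

# Simulate a few lookups: each prepared row bumps its counter.
for batch in (torch.tensor([3, 7, 1]), torch.tensor([3, 9]), torch.tensor([3])):
    freq_cnter[batch] += 1

# LFU choice: cache slots whose rows have the smallest counters are evicted first.
evict_slots = torch.argsort(freq_cnter[cached_idx_map])[:evict_num]
print(evict_slots.tolist(), cached_idx_map[evict_slots].tolist())   # e.g. slots [1, 2] holding rows [7, 1]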
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py

@@ -3,14 +3,13 @@ import torch.nn.functional as F
 from typing import List, Optional, Iterator, Tuple
 
 from .base_embedding import BaseEmbeddingBag
-from .cache_mgr import CachedParamMgr
+from .cache_mgr import CachedParamMgr, EvictionStrategy
 from torch.nn.parameter import Parameter
 
 
 class FreqAwareEmbeddingBag(BaseEmbeddingBag):
 
-    def __init__(
-            self,
+    def __init__(self,
                  num_embeddings,
                  embedding_dim,
                  padding_idx=None,

@@ -28,10 +27,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                  warmup_ratio=0.7,
                  buffer_size=50_000,
                  pin_weight=False,
-                 ):
+                 evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
         super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
                                                     scale_grad_by_freq, sparse, mode, include_last_offset)
 
+        self.evict_strategy = evict_strategy
         if _weight is None:
             _weight = self._weight_alloc(dtype, device)

@@ -63,7 +63,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
             ids_freq_mapping (List[int]): a list, idx is id number, value is freq
             warmup_ratio (float): the amount of rows preloaded in cuda cache
         """
-        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size, pin_weight)
+        self.cache_weight_mgr = CachedParamMgr(weight,
+                                               cuda_row_num,
+                                               buffer_size,
+                                               pin_weight,
+                                               evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
 
     def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None):
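For callers, the change surfaces as one new keyword argument. A usage sketch mirroring the updated test below, assuming a CUDA device and the import path touched by this commit (all sizes are made up for illustration):

import torch
from colossalai.nn.parallel.layers import FreqAwareEmbeddingBag, EvictionStrategy

device = torch.device('cuda', 0)

# Ask the cache manager to evict by LFU counters instead of the default
# dataset-frequency order (EvictionStrategy.DATASET).
bag = FreqAwareEmbeddingBag(1000,                     # num_embeddings (illustrative)
                            32,                       # embedding_dim (illustrative)
                            mode='mean',
                            include_last_offset=True,
                            cuda_row_num=64,
                            ids_freq_mapping=None,
                            evict_strategy=EvictionStrategy.LFU).to(device)

indices = torch.randint(0, 1000, (256,), device=device)
offsets = torch.arange(0, 257, 8, device=device)      # 33 offsets since include_last_offset=True
out = bag(indices, offsets)                           # (32, 32) pooled embeddings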
tests/test_layers/test_cache_embedding.py

@@ -12,7 +12,7 @@ from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
+from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy
 
 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8

@@ -41,6 +41,7 @@ def synthesize_1d_sparse_feature(
     return indices, offsets
 
 
+@pytest.mark.skip
 def test_cachemgr():
     model = torch.nn.EmbeddingBag(10000, 128)
     # 10 chunks, 5 in cuda

@@ -98,14 +99,17 @@ def test_reorder_with_freq():
         f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
 
 
-def test_freq_aware_embed():
+@pytest.mark.parametrize('use_LFU', [True, False])
+def test_freq_aware_embed(use_LFU: bool):
     device = torch.device('cuda', 0)
+    evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
     model = FreqAwareEmbeddingBag(NUM_EMBED,
                                   EMBED_DIM,
                                   mode='mean',
                                   include_last_offset=True,
                                   cuda_row_num=BATCH_SIZE * 2,
-                                  ids_freq_mapping=None).to(device)
+                                  ids_freq_mapping=None,
+                                  evict_strategy=evict_strategy).to(device)
     assert model.weight.shape[0] == NUM_EMBED
     ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),

@@ -231,6 +235,5 @@ def test_parallel_freq_aware_embed(world_size):
 if __name__ == '__main__':
-    test_cachemgr()
-    # test_freq_aware_embed()
+    test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)
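To exercise both eviction strategies locally, the parametrized test can be run under pytest (a CUDA device is assumed, since the test allocates cuda:0); this is one possible invocation from the repository root, not part of the commit:

# Run only the updated test; pytest expands use_LFU into a True and a False case.
import pytest

raise SystemExit(pytest.main(["-q", "tests/test_layers/test_cache_embedding.py::test_freq_aware_embed"]))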