Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a19eb809
Unverified
Commit
a19eb809
authored
Sep 15, 2022
by
Jiarui Fang
Committed by
GitHub
Sep 15, 2022
Browse files
[embedding] updates some default parameters
parent
cd5cf2bc
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
18 additions
and
19 deletions
+18
-19
benchmark
benchmark
+0
-1
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+6
-4
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
...n/parallel/layers/cache_embedding/freq_aware_embedding.py
+6
-6
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
...ache_embedding/parallel_freq_aware_embedding_tablewise.py
+6
-7
examples
examples
+0
-1
No files found.
benchmark
@
9ab77e0e
Compare
9ab77e0e
...
9ab77e0e
Subproject commit 9ab77e0ecc8e4ff480704dac2535b9c8f44f47b2
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
View file @
a19eb809
...
...
@@ -35,7 +35,7 @@ class CachedParamMgr(torch.nn.Module):
self
,
weight
:
torch
.
Tensor
,
cuda_row_num
:
int
=
0
,
buffer_size
:
int
=
50_00
0
,
buffer_size
:
int
=
0
,
pin_weight
:
bool
=
False
,
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
DATASET
,
use_cpu_caching
=
False
,
...
...
@@ -211,7 +211,7 @@ class CachedParamMgr(torch.nn.Module):
freq_value
,
preload_cpu_ids
=
torch
.
topk
(
ids_freq_mapping
,
preload_row_num
,
dim
=
0
,
largest
=
True
)
preload_cuda_row_idxs
=
torch
.
arange
(
preload_row_num
).
to
(
self
.
_cache_dev
)
else
:
preload_cpu_ids
=
torch
.
arange
(
preload_row_num
)
preload_cpu_ids
=
torch
.
arange
(
preload_row_num
,
device
=
self
.
weight
.
device
)
preload_cuda_row_idxs
=
preload_cpu_ids
.
to
(
self
.
_cache_dev
)
if
self
.
buffer_size
>
0
:
...
...
@@ -304,8 +304,10 @@ class CachedParamMgr(torch.nn.Module):
self
.
evict_backlist
=
cpu_row_idxs
with
record_function
(
"(pre-id) get cpu row idxs"
):
comm_cpu_row_idxs
=
cpu_row_idxs
[
torch
.
isin
(
cpu_row_idxs
,
self
.
cached_idx_map
,
assume_unique
=
True
,
invert
=
True
)]
comm_cpu_row_idxs
=
cpu_row_idxs
[
torch
.
isin
(
cpu_row_idxs
,
self
.
cached_idx_map
,
assume_unique
=
True
,
invert
=
True
)]
self
.
num_hits_history
.
append
(
len
(
cpu_row_idxs
)
-
len
(
comm_cpu_row_idxs
))
self
.
num_miss_history
.
append
(
len
(
comm_cpu_row_idxs
))
...
...
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
View file @
a19eb809
...
...
@@ -30,7 +30,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
cuda_row_num (int, optional): the max number of embedding vectors in the cuda cache. Defaults to 0.
ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency with which each embedding vector occurs in the dataset. Defaults to None.
warmup_ratio (float, optional): the ratio of the cuda cache that is warmed up. Defaults to 0.7.
buffer_size (int, optional): the max number of vectors in transmitter buffer. Defaults to 50_000.
buffer_size (int, optional): the max number of vectors in transmitter buffer.
If set to 0, the buffer is not used.
Defaults to 0.
pin_weight (bool, optional): pin the cpu weight. Defaults to False.
evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET.
"""
...
...
@@ -51,9 +51,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
cuda_row_num
:
int
=
0
,
ids_freq_mapping
:
Optional
[
Union
[
List
,
torch
.
Tensor
]]
=
None
,
warmup_ratio
:
float
=
0.7
,
buffer_size
:
int
=
50_00
0
,
buffer_size
:
int
=
0
,
pin_weight
:
bool
=
False
,
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
DATASET
):
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
LFU
):
super
(
FreqAwareEmbeddingBag
,
self
).
__init__
(
num_embeddings
,
embedding_dim
,
padding_idx
,
max_norm
,
norm_type
,
scale_grad_by_freq
,
sparse
,
mode
,
include_last_offset
)
...
...
@@ -96,9 +96,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
evict_strategy
=
self
.
evict_strategy
)
self
.
cache_weight_mgr
.
reorder
(
ids_freq_mapping
,
warmup_ratio
)
def
forward
(
self
,
in
dices
,
offsets
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
):
def
forward
(
self
,
in
put
,
offsets
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
):
with
torch
.
no_grad
():
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
in
dices
)
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
in
put
)
embeddings
=
F
.
embedding_bag
(
reorder_ids
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
...
...
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
View file @
a19eb809
...
...
@@ -123,7 +123,6 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
local_per_sample_weights_list
:
List
(
torch
.
Tensor
)
=
[]
offset_pre_end
=
0
# local_offsets trick
for
i
,
handle_table
in
enumerate
(
self
.
assigned_table_list
):
indices_start_position
=
offsets
[
batch_size
*
handle_table
]
if
(
not
self
.
include_last_offset
)
and
(
batch_size
*
(
handle_table
+
1
)
>=
indices
.
shape
[
0
]):
...
...
@@ -162,15 +161,15 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
# till-the-end special case
if
not
self
.
include_last_offset
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
batch_size
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
else
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
local_offsets_list
.
append
(
local_offsets
)
else
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
offset_pre_end
=
local_offsets
[
-
1
]
local_offsets_list
.
append
(
local_offsets
[:
-
1
])
# 3. local_per_sample_weights_list:
...
...
examples
@
757514d2
Compare
757514d2
...
757514d2
Subproject commit 757514d2b1501d3530777cdf567f0a18063acf2d
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment