Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a389ac4e
Unverified
Commit
a389ac4e
authored
Sep 08, 2022
by
CsRic
Committed by
GitHub
Sep 08, 2022
Browse files
[embedding] cache_embedding small improvement (#1564)
parent
10dd8226
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
13 deletions
+35
-13
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+4
-3
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
...l/layers/cache_embedding/parallel_freq_aware_embedding.py
+2
-3
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
...ache_embedding/parallel_freq_aware_embedding_tablewise.py
+29
-7
No files found.
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
View file @
a389ac4e
...
@@ -178,7 +178,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -178,7 +178,7 @@ class CachedParamMgr(torch.nn.Module):
"""reorder
"""reorder
reorder the weight according to ids' frequency in dataset before training.
reorder the weight according to ids' frequency in dataset before training.
Execute only once before training, also known as warmup phase.
Execute only once before training, also known as warmup phase.
Note:
Note:
If you would like to use the DATASET as the eviction strategy, you must call this function.
If you would like to use the DATASET as the eviction strategy, you must call this function.
...
@@ -304,7 +304,8 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -304,7 +304,8 @@ class CachedParamMgr(torch.nn.Module):
self
.
evict_backlist
=
cpu_row_idxs
self
.
evict_backlist
=
cpu_row_idxs
with
record_function
(
"(pre-id) get cpu row idxs"
):
with
record_function
(
"(pre-id) get cpu row idxs"
):
comm_cpu_row_idxs
=
cpu_row_idxs
[
torch
.
isin
(
cpu_row_idxs
,
self
.
cached_idx_map
,
invert
=
True
)]
comm_cpu_row_idxs
=
cpu_row_idxs
[
torch
.
isin
(
cpu_row_idxs
,
self
.
cached_idx_map
,
assume_unique
=
True
,
invert
=
True
)]
self
.
num_hits_history
.
append
(
len
(
cpu_row_idxs
)
-
len
(
comm_cpu_row_idxs
))
self
.
num_hits_history
.
append
(
len
(
cpu_row_idxs
)
-
len
(
comm_cpu_row_idxs
))
self
.
num_miss_history
.
append
(
len
(
comm_cpu_row_idxs
))
self
.
num_miss_history
.
append
(
len
(
comm_cpu_row_idxs
))
...
@@ -345,7 +346,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -345,7 +346,7 @@ class CachedParamMgr(torch.nn.Module):
evict_num
=
cpu_row_idxs
.
numel
()
-
self
.
cuda_available_row_num
evict_num
=
cpu_row_idxs
.
numel
()
-
self
.
cuda_available_row_num
if
evict_num
>
0
:
if
evict_num
>
0
:
with
Timer
()
as
timer
:
with
Timer
()
as
timer
:
mask_cpu_row_idx
=
torch
.
isin
(
self
.
cached_idx_map
,
self
.
evict_backlist
)
mask_cpu_row_idx
=
torch
.
isin
(
self
.
cached_idx_map
,
self
.
evict_backlist
,
assume_unique
=
True
)
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
# mask method.
# mask method.
...
...
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
View file @
a389ac4e
...
@@ -75,7 +75,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
...
@@ -75,7 +75,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
def
forward
(
self
,
indices
,
offsets
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
,
scatter_dim
=
0
,
gather_dim
=-
1
):
def
forward
(
self
,
indices
,
offsets
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
,
scatter_dim
=
0
,
gather_dim
=-
1
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
indices
)
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
indices
)
output_shard
=
F
.
embedding_bag
(
reorder_ids
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
output_shard
=
F
.
embedding_bag
(
reorder_ids
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
...
@@ -124,6 +123,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
...
@@ -124,6 +123,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
def
print_comm_stats_
(
self
):
def
print_comm_stats_
(
self
):
self
.
cache_weight_mgr
.
print_comm_stats
()
self
.
cache_weight_mgr
.
print_comm_stats
()
def
element_size
(
self
):
def
element_size
(
self
):
return
self
.
weight
.
element_size
()
return
self
.
weight
.
element_size
()
\ No newline at end of file
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
View file @
a389ac4e
...
@@ -87,6 +87,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
...
@@ -87,6 +87,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
local_per_sample_weights_list
:
List
(
torch
.
Tensor
)
=
[]
local_per_sample_weights_list
:
List
(
torch
.
Tensor
)
=
[]
offset_pre_end
=
0
# local_offsets trick
offset_pre_end
=
0
# local_offsets trick
for
i
,
handle_table
in
enumerate
(
self
.
assigned_table_list
):
for
i
,
handle_table
in
enumerate
(
self
.
assigned_table_list
):
indices_start_position
=
offsets
[
batch_size
*
handle_table
]
indices_start_position
=
offsets
[
batch_size
*
handle_table
]
if
(
not
self
.
include_last_offset
)
and
(
batch_size
*
(
handle_table
+
1
)
>=
indices
.
shape
[
0
]):
if
(
not
self
.
include_last_offset
)
and
(
batch_size
*
(
handle_table
+
1
)
>=
indices
.
shape
[
0
]):
...
@@ -94,6 +95,28 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
...
@@ -94,6 +95,28 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
indices_end_position
=
indices
.
shape
[
0
]
indices_end_position
=
indices
.
shape
[
0
]
else
:
else
:
indices_end_position
=
offsets
[
batch_size
*
(
handle_table
+
1
)]
indices_end_position
=
offsets
[
batch_size
*
(
handle_table
+
1
)]
# alternative approach: reduce malloc
'''
# 1. local_indices_list:
local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position)
torch.sub(local_indices, self.idx_offset_list[i], out=local_indices)
local_indices_list.append(local_indices)
# 2. local_offsets_list:
if i + 1 == len(self.assigned_table_list):
# till-the-end special case
if not self.include_last_offset:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1)
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
local_offsets_list.append(local_offsets)
else:
temp_holder = offsets[batch_size * handle_table].item()
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder
local_offsets_list.append(local_offsets)
'''
# 1. local_indices_list:
# 1. local_indices_list:
local_indices_list
.
append
(
local_indices_list
.
append
(
indices
.
narrow
(
0
,
indices_start_position
,
indices
.
narrow
(
0
,
indices_start_position
,
...
@@ -103,21 +126,20 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
...
@@ -103,21 +126,20 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
# till-the-end special case
# till-the-end special case
if
not
self
.
include_last_offset
:
if
not
self
.
include_last_offset
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
batch_size
).
add
(
offset_pre_end
-
offsets
[
batch_size
(
handle_table
)])
*
(
handle_table
)])
else
:
else
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
local_offsets_list
.
append
(
local_offsets
)
local_offsets_list
.
append
(
local_offsets
)
else
:
else
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
offset_pre_end
=
local_offsets
[
-
1
]
offset_pre_end
=
local_offsets
[
-
1
]
local_offsets_list
.
append
(
local_offsets
[:
-
1
])
local_offsets_list
.
append
(
local_offsets
[:
-
1
])
# 3. local_per_sample_weights_list:
# 3. local_per_sample_weights_list:
if
per_sample_weights
!=
None
:
if
per_sample_weights
!=
None
:
local_per_sample_weights_list
.
append
(
per_sample_weights
[
indices_start_position
:
indices_end_position
])
local_per_sample_weights_list
.
append
(
per_sample_weights
[
indices_start_position
:
indices_end_position
])
local_indices
=
torch
.
cat
(
local_indices_list
,
0
)
local_indices
=
torch
.
cat
(
local_indices_list
,
0
)
local_offsets
=
torch
.
cat
(
local_offsets_list
,
0
)
local_offsets
=
torch
.
cat
(
local_offsets_list
,
0
)
local_per_sample_weights
=
None
local_per_sample_weights
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment