Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a389ac4e
Unverified
Commit
a389ac4e
authored
Sep 08, 2022
by
CsRic
Committed by
GitHub
Sep 08, 2022
Browse files
[embedding] cache_embedding small improvement (#1564)
parent
10dd8226
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
13 deletions
+35
-13
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+4
-3
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
...l/layers/cache_embedding/parallel_freq_aware_embedding.py
+2
-3
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
...ache_embedding/parallel_freq_aware_embedding_tablewise.py
+29
-7
No files found.
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
View file @
a389ac4e
...
...
@@ -304,7 +304,8 @@ class CachedParamMgr(torch.nn.Module):
self
.
evict_backlist
=
cpu_row_idxs
with
record_function
(
"(pre-id) get cpu row idxs"
):
comm_cpu_row_idxs
=
cpu_row_idxs
[
torch
.
isin
(
cpu_row_idxs
,
self
.
cached_idx_map
,
invert
=
True
)]
comm_cpu_row_idxs
=
cpu_row_idxs
[
torch
.
isin
(
cpu_row_idxs
,
self
.
cached_idx_map
,
assume_unique
=
True
,
invert
=
True
)]
self
.
num_hits_history
.
append
(
len
(
cpu_row_idxs
)
-
len
(
comm_cpu_row_idxs
))
self
.
num_miss_history
.
append
(
len
(
comm_cpu_row_idxs
))
...
...
@@ -345,7 +346,7 @@ class CachedParamMgr(torch.nn.Module):
evict_num
=
cpu_row_idxs
.
numel
()
-
self
.
cuda_available_row_num
if
evict_num
>
0
:
with
Timer
()
as
timer
:
mask_cpu_row_idx
=
torch
.
isin
(
self
.
cached_idx_map
,
self
.
evict_backlist
)
mask_cpu_row_idx
=
torch
.
isin
(
self
.
cached_idx_map
,
self
.
evict_backlist
,
assume_unique
=
True
)
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
# mask method.
...
...
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
View file @
a389ac4e
...
...
@@ -75,7 +75,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
def
forward
(
self
,
indices
,
offsets
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
,
scatter_dim
=
0
,
gather_dim
=-
1
):
with
torch
.
no_grad
():
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
indices
)
output_shard
=
F
.
embedding_bag
(
reorder_ids
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
...
...
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
View file @
a389ac4e
...
...
@@ -87,6 +87,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
local_per_sample_weights_list
:
List
(
torch
.
Tensor
)
=
[]
offset_pre_end
=
0
# local_offsets trick
for
i
,
handle_table
in
enumerate
(
self
.
assigned_table_list
):
indices_start_position
=
offsets
[
batch_size
*
handle_table
]
if
(
not
self
.
include_last_offset
)
and
(
batch_size
*
(
handle_table
+
1
)
>=
indices
.
shape
[
0
]):
...
...
@@ -94,6 +95,28 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
indices_end_position
=
indices
.
shape
[
0
]
else
:
indices_end_position
=
offsets
[
batch_size
*
(
handle_table
+
1
)]
# alternative approach: reduce malloc
'''
# 1. local_indices_list:
local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position)
torch.sub(local_indices, self.idx_offset_list[i], out=local_indices)
local_indices_list.append(local_indices)
# 2. local_offsets_list:
if i + 1 == len(self.assigned_table_list):
# till-the-end special case
if not self.include_last_offset:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1)
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
local_offsets_list.append(local_offsets)
else:
temp_holder = offsets[batch_size * handle_table].item()
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder
local_offsets_list.append(local_offsets)
'''
# 1. local_indices_list:
local_indices_list
.
append
(
indices
.
narrow
(
0
,
indices_start_position
,
...
...
@@ -103,21 +126,20 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
# till-the-end special case
if
not
self
.
include_last_offset
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
batch_size
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
else
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
local_offsets_list
.
append
(
local_offsets
)
else
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
offset_pre_end
=
local_offsets
[
-
1
]
local_offsets_list
.
append
(
local_offsets
[:
-
1
])
# 3. local_per_sample_weights_list:
if
per_sample_weights
!=
None
:
local_per_sample_weights_list
.
append
(
per_sample_weights
[
indices_start_position
:
indices_end_position
])
local_indices
=
torch
.
cat
(
local_indices_list
,
0
)
local_offsets
=
torch
.
cat
(
local_offsets_list
,
0
)
local_per_sample_weights
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment