Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e57df803
Unverified
Commit
e57df803
authored
Sep 23, 2022
by
Jiarui Fang
Committed by
GitHub
Sep 23, 2022
Browse files
[embeddings] cache option (#1635)
parent
a088022e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
14 deletions
+25
-14
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
...n/parallel/layers/cache_embedding/freq_aware_embedding.py
+6
-5
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
...l/layers/cache_embedding/parallel_freq_aware_embedding.py
+13
-5
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
...ache_embedding/parallel_freq_aware_embedding_tablewise.py
+6
-4
No files found.
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
View file @
e57df803
...
@@ -97,12 +97,13 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
...
@@ -97,12 +97,13 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
evict_strategy
=
self
.
evict_strategy
)
evict_strategy
=
self
.
evict_strategy
)
self
.
cache_weight_mgr
.
reorder
(
ids_freq_mapping
,
warmup_ratio
)
self
.
cache_weight_mgr
.
reorder
(
ids_freq_mapping
,
warmup_ratio
)
def
forward
(
self
,
input
,
offsets
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
):
def
forward
(
self
,
input
,
offsets
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
,
cache_op
=
True
):
with
torch
.
no_grad
():
if
cache_op
:
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
input
)
with
torch
.
no_grad
():
input
=
self
.
cache_weight_mgr
.
prepare_ids
(
input
)
embeddings
=
F
.
embedding_bag
(
reorder_ids
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
embeddings
=
F
.
embedding_bag
(
input
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
self
.
max_norm
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
if
shape_hook
is
not
None
:
if
shape_hook
is
not
None
:
embeddings
=
shape_hook
(
embeddings
)
embeddings
=
shape_hook
(
embeddings
)
...
...
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
View file @
e57df803
...
@@ -72,11 +72,19 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
...
@@ -72,11 +72,19 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
compute_attr
=
ComputePattern
.
TP1D
)
compute_attr
=
ComputePattern
.
TP1D
)
return
ColoTensor
.
from_torch_tensor
(
weight
,
spec
=
colo_tensor_spec
)
return
ColoTensor
.
from_torch_tensor
(
weight
,
spec
=
colo_tensor_spec
)
def
forward
(
self
,
indices
,
offsets
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
,
scatter_dim
=
0
,
gather_dim
=-
1
):
def
forward
(
self
,
with
torch
.
no_grad
():
indices
,
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
indices
)
offsets
=
None
,
output_shard
=
F
.
embedding_bag
(
reorder_ids
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
per_sample_weights
=
None
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
shape_hook
=
None
,
scatter_dim
=
0
,
gather_dim
=-
1
,
cache_op
:
bool
=
True
):
if
cache_op
:
with
torch
.
no_grad
():
indices
=
self
.
cache_weight_mgr
.
prepare_ids
(
indices
)
output_shard
=
F
.
embedding_bag
(
indices
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
if
shape_hook
is
not
None
:
if
shape_hook
is
not
None
:
output_shard
=
shape_hook
(
output_shard
)
output_shard
=
shape_hook
(
output_shard
)
...
...
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
View file @
e57df803
...
@@ -86,7 +86,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
...
@@ -86,7 +86,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
offsets
:
torch
.
Tensor
=
None
,
offsets
:
torch
.
Tensor
=
None
,
per_sample_weights
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
,
shape_hook
=
None
,
already_split_along_rank
=
True
):
already_split_along_rank
=
True
,
cache_op
=
True
):
if
not
already_split_along_rank
:
if
not
already_split_along_rank
:
# not recommended. it takes time.
# not recommended. it takes time.
batch_size
=
(
offsets
.
shape
[
0
])
//
self
.
global_tables_num
batch_size
=
(
offsets
.
shape
[
0
])
//
self
.
global_tables_num
...
@@ -96,9 +97,10 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
...
@@ -96,9 +97,10 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
# recommended.
# recommended.
batch_size
=
(
offsets
.
shape
[
0
])
//
len
(
self
.
assigned_table_list
)
batch_size
=
(
offsets
.
shape
[
0
])
//
len
(
self
.
assigned_table_list
)
local_indices
,
local_offsets
,
local_per_sample_weights
=
indices
,
offsets
,
per_sample_weights
local_indices
,
local_offsets
,
local_per_sample_weights
=
indices
,
offsets
,
per_sample_weights
with
torch
.
no_grad
():
if
cache_op
:
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
local_indices
)
with
torch
.
no_grad
():
local_output
=
F
.
embedding_bag
(
reorder_ids
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
local_offsets
,
indices
=
self
.
cache_weight_mgr
.
prepare_ids
(
local_indices
)
local_output
=
F
.
embedding_bag
(
indices
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
local_offsets
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
local_per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
local_per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
local_output
=
torch
.
cat
(
local_output
.
split
(
batch_size
),
1
)
local_output
=
torch
.
cat
(
local_output
.
split
(
batch_size
),
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment