GitLab · OpenDAS / ColossalAI · Commits

Commit 0767f67a (unverified)
Authored Sep 26, 2022 by CsRic; committed by GitHub on Sep 26, 2022
Parent: c5d39215

[embedding] isolate cache_op from forward (#1645)

Co-authored-by: ric <mkkt_bkkt@mail.ustc.edu.cn>

Showing 3 changed files with 34 additions and 19 deletions:

colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py                     +6  -2
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py            +14 -9
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py  +14 -8
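What the change means for callers: cache_op is no longer a per-call keyword argument of forward(); it becomes module state, toggled through the new set_cache_op() method that all three classes gain. A minimal before/after sketch; the constructor arguments and import path are illustrative assumptions (the real class takes many more optional arguments and needs a CUDA device for its cached weight):

import torch
from colossalai.nn.parallel.layers.cache_embedding import FreqAwareEmbeddingBag

# Illustrative construction; see the class definition for the full signature.
bag = FreqAwareEmbeddingBag(num_embeddings=1000, embedding_dim=16)

ids = torch.randint(0, 1000, (8,))
offsets = torch.arange(0, 8, 2)

# Before this commit, cache management was requested per call:
#     out = bag(ids, offsets, cache_op=False)
# After this commit, toggle it once on the module, then call forward as usual.
bag.set_cache_op(False)    # forward() now skips cache_weight_mgr.prepare_ids
out = bag(ids, offsets)
bag.set_cache_op(True)     # restore normal cache management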
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py

@@ -64,6 +64,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
         cuda_row_num = int(num_embeddings * cache_ratio)
         # configure weight & cache
         self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
+        self.cache_op = True

     def _weight_alloc(self, dtype, device):
         weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)

@@ -97,8 +98,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                                                  evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)

-    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
-        if cache_op:
+    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
+        if self.cache_op:
             with torch.no_grad():
                 input = self.cache_weight_mgr.prepare_ids(input)

@@ -119,6 +120,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         yield self.cache_weight_mgr.cuda_cached_weight

+    def set_cache_op(self, cache_op: bool = True):
+        self.cache_op = cache_op
+
 ############################# Perf Log ###################################
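The control-flow pattern these hunks introduce, boiled down to a self-contained sketch: the flag lives on the instance, forward() consults it, and the id remapping stays outside the autograd graph. The cache manager is mocked here (a plain identity _prepare_ids); only the control flow mirrors the diff:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedBag(nn.Module):
    """Toy reduction of the FreqAwareEmbeddingBag pattern above."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
        self.cache_op = True    # instance state replaces the old forward() kwarg

    def set_cache_op(self, cache_op: bool = True):
        self.cache_op = cache_op

    def _prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:
        # Stand-in for cache_weight_mgr.prepare_ids: the real manager admits
        # rows to the CUDA cache and returns the corresponding cache-row ids.
        return ids

    def forward(self, input: torch.Tensor, offsets: torch.Tensor = None):
        if self.cache_op:
            # id remapping is bookkeeping, so it is kept out of autograd
            with torch.no_grad():
                input = self._prepare_ids(input)
        return F.embedding_bag(input, self.weight, offsets)

bag = CachedBag(100, 8)
out = bag(torch.randint(0, 100, (6,)), offsets=torch.tensor([0, 3]))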
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py

@@ -60,6 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio,
                              ids_freq_mapping, warmup_ratio, buffer_size, pin_weight, evict_strategy)
+        self.cache_op = True

     def _weight_alloc(self, dtype, device):
         weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)

@@ -72,15 +73,16 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                               compute_attr=ComputePattern.TP1D)
         return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)

-    def forward(self,
-                indices,
-                offsets=None,
-                per_sample_weights=None,
-                shape_hook=None,
-                scatter_dim=0,
-                gather_dim=-1,
-                cache_op: bool = True):
-        if cache_op:
+    def forward(self,
+                indices,
+                offsets=None,
+                per_sample_weights=None,
+                shape_hook=None,
+                scatter_dim=0,
+                gather_dim=-1,
+                ):
+        if self.cache_op:
             with torch.no_grad():
                 indices = self.cache_weight_mgr.prepare_ids(indices)
         output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
                                        self.max_norm,

@@ -94,6 +96,9 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                            gather_dim=gather_dim)
         return output_full

+    def set_cache_op(self, cache_op: bool = True):
+        self.cache_op = cache_op
+
     @classmethod
     def from_pretrained(cls,
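One plausible use the setter enables: when cache_op is off, forward() indexes cuda_cached_weight directly, so indices that have already been through cache_weight_mgr.prepare_ids can be reused across several calls without redundant remapping. A hedged sketch under the assumption that bag is a constructed ParallelFreqAwareEmbeddingBag and remapped_indices are already cache-row ids:

# Assumed names: bag, remapped_indices, offsets are set up by the caller.
bag.set_cache_op(False)
try:
    out = bag(remapped_indices, offsets)   # no second prepare_ids pass
finally:
    bag.set_cache_op(True)                 # re-enable cache management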
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py

@@ -81,13 +81,16 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         for rank in self.rank_of_tables:
             self.embedding_dim_per_rank[rank] += embedding_dim
+        self.cache_op = True

-    def forward(self,
-                indices: torch.Tensor,
-                offsets: torch.Tensor = None,
-                per_sample_weights=None,
-                shape_hook=None,
-                already_split_along_rank=True,
-                cache_op=True):
+    def forward(self,
+                indices: torch.Tensor,
+                offsets: torch.Tensor = None,
+                per_sample_weights=None,
+                shape_hook=None,
+                already_split_along_rank=True,
+                ):
         if not already_split_along_rank:
             # not recommanded. it takes time.
             batch_size = (offsets.shape[0]) // self.global_tables_num

@@ -97,7 +100,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             # recommanded.
             batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
             local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
-        if cache_op:
+        if self.cache_op:
             with torch.no_grad():
                 indices = self.cache_weight_mgr.prepare_ids(local_indices)
         local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,

@@ -185,6 +188,9 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
         return local_indices, local_offsets, local_per_sample_weights

+    def set_cache_op(self, cache_op: bool = True):
+        self.cache_op = cache_op
+
     def print_comm_stats_(self):
         self.cache_weight_mgr.print_comm_stats()
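Since all three classes now expose an identical set_cache_op(), a model-wide toggle can be written generically. A minimal sketch; the helper name and the duck-typed hasattr check are my own, not part of this commit:

import torch.nn as nn

def set_cache_op_recursively(model: nn.Module, cache_op: bool = True):
    # Hypothetical helper: flip cache_op on every submodule that exposes
    # the setter added in this commit (FreqAwareEmbeddingBag and subclasses).
    for module in model.modules():
        if hasattr(module, "set_cache_op"):
            module.set_cache_op(cache_op)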