Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
964123ae
Unverified
Commit
964123ae
authored
Sep 05, 2022
by
CsRic
Committed by
GitHub
Sep 05, 2022
Browse files
[embedding] freq_aware_embedding: add small functions for caller application (#1537)
parent
70129603
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
214 additions
and
46 deletions
+214
-46
colossalai/nn/parallel/layers/__init__.py
colossalai/nn/parallel/layers/__init__.py
+3
-2
colossalai/nn/parallel/layers/cache_embedding/__init__.py
colossalai/nn/parallel/layers/cache_embedding/__init__.py
+2
-2
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
...l/layers/cache_embedding/parallel_freq_aware_embedding.py
+6
-0
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
...ache_embedding/parallel_freq_aware_embedding_tablewise.py
+178
-23
tests/test_layers/test_cache_embedding.py
tests/test_layers/test_cache_embedding.py
+25
-19
No files found.
colossalai/nn/parallel/layers/__init__.py
View file @
964123ae
...
@@ -4,10 +4,11 @@ from .embedding import ColoEmbedding
...
@@ -4,10 +4,11 @@ from .embedding import ColoEmbedding
from
.module_utils
import
register_colo_module
,
is_colo_module
,
get_colo_module
,
init_colo_module
,
check_colo_module
from
.module_utils
import
register_colo_module
,
is_colo_module
,
get_colo_module
,
init_colo_module
,
check_colo_module
from
.cache_embedding
import
FreqAwareEmbeddingBag
,
ParallelFreqAwareEmbeddingBag
,
CachedParamMgr
,
LimitBuffIndexCopyer
,
EvictionStrategy
,
\
from
.cache_embedding
import
FreqAwareEmbeddingBag
,
ParallelFreqAwareEmbeddingBag
,
CachedParamMgr
,
LimitBuffIndexCopyer
,
EvictionStrategy
,
\
ParallelFreqAwareEmbeddingBagTablewise
,
TablewiseEmbeddingBagConfig
ParallelFreqAwareEmbeddingBagTablewise
,
TablewiseEmbeddingBagConfig
,
ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
__all__
=
[
__all__
=
[
'ColoModule'
,
'register_colo_module'
,
'is_colo_module'
,
'get_colo_module'
,
'init_colo_module'
,
'check_colo_module'
,
'ColoModule'
,
'register_colo_module'
,
'is_colo_module'
,
'get_colo_module'
,
'init_colo_module'
,
'check_colo_module'
,
'ColoLinear'
,
'ColoEmbedding'
,
'FreqAwareEmbeddingBag'
,
'ParallelFreqAwareEmbeddingBag'
,
'CachedParamMgr'
,
'ColoLinear'
,
'ColoEmbedding'
,
'FreqAwareEmbeddingBag'
,
'ParallelFreqAwareEmbeddingBag'
,
'CachedParamMgr'
,
'LimitBuffIndexCopyer'
,
'EvictionStrategy'
,
'ParallelFreqAwareEmbeddingBagTablewise'
,
'TablewiseEmbeddingBagConfig'
'LimitBuffIndexCopyer'
,
'EvictionStrategy'
,
'ParallelFreqAwareEmbeddingBagTablewise'
,
'TablewiseEmbeddingBagConfig'
,
'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
]
]
colossalai/nn/parallel/layers/cache_embedding/__init__.py
View file @
964123ae
...
@@ -2,8 +2,8 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy
...
@@ -2,8 +2,8 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy
from
.copyer
import
LimitBuffIndexCopyer
from
.copyer
import
LimitBuffIndexCopyer
from
.freq_aware_embedding
import
FreqAwareEmbeddingBag
from
.freq_aware_embedding
import
FreqAwareEmbeddingBag
from
.parallel_freq_aware_embedding
import
ParallelFreqAwareEmbeddingBag
from
.parallel_freq_aware_embedding
import
ParallelFreqAwareEmbeddingBag
from
.parallel_freq_aware_embedding_tablewise
import
ParallelFreqAwareEmbeddingBagTablewise
,
TablewiseEmbeddingBagConfig
from
.parallel_freq_aware_embedding_tablewise
import
ParallelFreqAwareEmbeddingBagTablewise
,
TablewiseEmbeddingBagConfig
,
ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
__all__
=
[
__all__
=
[
'CachedParamMgr'
,
'LimitBuffIndexCopyer'
,
'FreqAwareEmbeddingBag'
,
'ParallelFreqAwareEmbeddingBag'
,
'CachedParamMgr'
,
'LimitBuffIndexCopyer'
,
'FreqAwareEmbeddingBag'
,
'ParallelFreqAwareEmbeddingBag'
,
'EvictionStrategy'
,
'ParallelFreqAwareEmbeddingBagTablewise'
,
'TablewiseEmbeddingBagConfig'
'EvictionStrategy'
,
'ParallelFreqAwareEmbeddingBagTablewise'
,
'TablewiseEmbeddingBagConfig'
,
'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
]
]
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
View file @
964123ae
...
@@ -121,3 +121,9 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
...
@@ -121,3 +121,9 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
buffer_size
=
buffer_size
)
buffer_size
=
buffer_size
)
embedding_bag
.
cache_weight_mgr
.
cuda_cached_weight
.
requires_grad_
=
not
freeze
embedding_bag
.
cache_weight_mgr
.
cuda_cached_weight
.
requires_grad_
=
not
freeze
return
embedding_bag
return
embedding_bag
def
print_comm_stats_
(
self
):
self
.
cache_weight_mgr
.
print_comm_stats
()
def
element_size
(
self
):
return
self
.
weight
.
element_size
()
\ No newline at end of file
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
View file @
964123ae
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.profiler
import
record_function
from
typing
import
List
from
typing
import
List
import
abc
import
abc
import
torch.nn.functional
as
F
from
.freq_aware_embedding
import
FreqAwareEmbeddingBag
from
.freq_aware_embedding
import
FreqAwareEmbeddingBag
from
colossalai.tensor
import
ProcessGroup
from
colossalai.tensor
import
ProcessGroup
...
@@ -38,7 +39,137 @@ class TablewiseEmbeddingBagConfig:
...
@@ -38,7 +39,137 @@ class TablewiseEmbeddingBagConfig:
self
.
name
=
name
self
.
name
=
name
class
ParallelFreqAwareEmbeddingBagTablewise
(
abc
.
ABC
,
nn
.
Module
):
class
ParallelFreqAwareEmbeddingBagTablewise
(
FreqAwareEmbeddingBag
):
"""
all tables assigned to this class instance are managed by a single FreqAwareEmbeddingBag.
Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
"""
def
__init__
(
self
,
embedding_bag_config_list
:
List
[
TablewiseEmbeddingBagConfig
],
embedding_dim
:
int
,
padding_idx
=
None
,
max_norm
=
None
,
norm_type
=
2.
,
scale_grad_by_freq
=
False
,
sparse
=
False
,
_weight
=
None
,
mode
=
'mean'
,
include_last_offset
=
False
,
dtype
=
None
,
device
=
None
,
cuda_row_num
=
0
,
warmup_ratio
=
0.7
,
buffer_size
=
50_000
,
pin_weight
=
False
,
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
LFU
):
self
.
rank
=
dist
.
get_rank
()
self
.
world_size
=
dist
.
get_world_size
()
self
.
rank_of_tables
=
[
config
.
assigned_rank
for
config
in
embedding_bag_config_list
]
self
.
global_table_num_embeddings_list
=
[
config
.
num_embeddings
for
config
in
embedding_bag_config_list
]
self
.
global_tables_num
=
len
(
embedding_bag_config_list
)
self
.
global_tables_offsets
=
torch
.
cumsum
(
torch
.
tensor
([
0
]
+
self
.
global_table_num_embeddings_list
),
0
).
cuda
()
self
.
assigned_table_list
:
List
[
int
]
=
[]
self
.
pg
=
ProcessGroup
(
tp_degree
=
self
.
world_size
)
self
.
num_embeddings
=
0
for
i
,
rank
in
enumerate
(
self
.
rank_of_tables
):
if
rank
==
self
.
rank
:
self
.
assigned_table_list
.
append
(
i
)
self
.
num_embeddings
+=
self
.
global_table_num_embeddings_list
[
i
]
self
.
include_last_offset
=
include_last_offset
ids_freq_mapping
=
[]
for
config
in
embedding_bag_config_list
:
if
config
.
assigned_rank
==
self
.
rank
:
if
config
.
ids_freq_mapping
!=
None
:
ids_freq_mapping
.
extend
(
config
.
ids_freq_mapping
)
else
:
ids_freq_mapping
=
None
break
# table-associate cache
super
(
ParallelFreqAwareEmbeddingBagTablewise
,
self
).
__init__
(
self
.
num_embeddings
,
embedding_dim
,
padding_idx
,
max_norm
,
norm_type
,
scale_grad_by_freq
,
sparse
,
_weight
,
mode
,
include_last_offset
,
dtype
,
device
,
cuda_row_num
,
ids_freq_mapping
,
warmup_ratio
,
buffer_size
,
pin_weight
,
evict_strategy
)
# for assigned tables reconnection:
self
.
idx_offset_list
=
[]
offset_cumsum
=
0
for
table_i
,
table_num_embeddings
in
enumerate
(
self
.
global_table_num_embeddings_list
):
if
self
.
rank_of_tables
[
table_i
]
==
self
.
rank
:
self
.
idx_offset_list
.
append
(
offset_cumsum
)
else
:
offset_cumsum
+=
table_num_embeddings
# prepare list shape for all_to_all output
self
.
embedding_dim_per_rank
=
[
0
for
i
in
range
(
self
.
world_size
)]
for
rank
in
self
.
rank_of_tables
:
self
.
embedding_dim_per_rank
[
rank
]
+=
embedding_dim
def
forward
(
self
,
indices
:
torch
.
Tensor
,
offsets
:
torch
.
Tensor
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
):
batch_size
=
(
offsets
.
shape
[
0
])
//
self
.
global_tables_num
local_indices_list
:
List
(
torch
.
Tensor
)
=
[]
local_offsets_list
:
List
(
torch
.
Tensor
)
=
[]
if
per_sample_weights
!=
None
:
local_per_sample_weights_list
:
List
(
torch
.
Tensor
)
=
[]
offset_pre_end
=
0
# local_offsets trick
for
i
,
handle_table
in
enumerate
(
self
.
assigned_table_list
):
indices_start_position
=
offsets
[
batch_size
*
handle_table
]
if
(
not
self
.
include_last_offset
)
and
(
batch_size
*
(
handle_table
+
1
)
>=
indices
.
shape
[
0
]):
# till-the-end special case
indices_end_position
=
indices
.
shape
[
0
]
else
:
indices_end_position
=
offsets
[
batch_size
*
(
handle_table
+
1
)]
# 1. local_indices_list:
local_indices_list
.
append
(
indices
.
narrow
(
0
,
indices_start_position
,
indices_end_position
-
indices_start_position
).
sub
(
self
.
idx_offset_list
[
i
]))
# 2. local_offsets_list:
if
i
+
1
==
len
(
self
.
assigned_table_list
):
# till-the-end special case
if
not
self
.
include_last_offset
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
else
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
local_offsets_list
.
append
(
local_offsets
)
else
:
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
batch_size
+
1
).
add
(
offset_pre_end
-
offsets
[
batch_size
*
(
handle_table
)])
offset_pre_end
=
local_offsets
[
-
1
]
local_offsets_list
.
append
(
local_offsets
[:
-
1
])
# 3. local_per_sample_weights_list:
if
per_sample_weights
!=
None
:
local_per_sample_weights_list
.
append
(
per_sample_weights
[
indices_start_position
:
indices_end_position
])
local_indices
=
torch
.
cat
(
local_indices_list
,
0
)
local_offsets
=
torch
.
cat
(
local_offsets_list
,
0
)
local_per_sample_weights
=
None
if
per_sample_weights
!=
None
:
local_per_sample_weights
=
torch
.
cat
(
local_per_sample_weights_list
,
0
)
with
torch
.
no_grad
():
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
local_indices
)
local_output
=
F
.
embedding_bag
(
reorder_ids
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
local_offsets
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
local_per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
local_output
=
torch
.
cat
(
local_output
.
split
(
batch_size
),
1
)
remains
=
batch_size
%
self
.
world_size
scatter_strides
=
[
batch_size
//
self
.
world_size
+
int
(
i
<
remains
)
for
i
in
range
(
self
.
world_size
)]
output_full
=
dual_all_to_all_tablewise
(
local_output
,
self
.
pg
,
scatter_strides
,
self
.
embedding_dim_per_rank
)
if
shape_hook
is
not
None
:
output_full
=
shape_hook
(
output_full
)
return
output_full
def
print_comm_stats_
(
self
):
self
.
cache_weight_mgr
.
print_comm_stats
()
def
element_size
(
self
):
return
self
.
weight
.
element_size
()
class
ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
(
abc
.
ABC
,
nn
.
Module
):
"""
"""
every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
"""
"""
...
@@ -58,7 +189,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
...
@@ -58,7 +189,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
warmup_ratio
=
0.7
,
warmup_ratio
=
0.7
,
pin_weight
=
False
,
pin_weight
=
False
,
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
LFU
):
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
LFU
):
super
(
ParallelFreqAwareEmbeddingBagTablewise
,
self
).
__init__
()
super
(
ParallelFreqAwareEmbeddingBagTablewise
SpiltCache
,
self
).
__init__
()
self
.
rank
=
dist
.
get_rank
()
self
.
rank
=
dist
.
get_rank
()
self
.
world_size
=
dist
.
get_world_size
()
self
.
world_size
=
dist
.
get_world_size
()
self
.
rank_of_tables
=
[
config
.
assigned_rank
for
config
in
embedding_bag_config_list
]
self
.
rank_of_tables
=
[
config
.
assigned_rank
for
config
in
embedding_bag_config_list
]
...
@@ -109,26 +240,32 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
...
@@ -109,26 +240,32 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
batch_size
=
(
offsets
.
shape
[
0
])
//
self
.
global_tables_num
batch_size
=
(
offsets
.
shape
[
0
])
//
self
.
global_tables_num
local_output_list
=
[]
local_output_list
=
[]
for
i
,
handle_table
in
enumerate
(
self
.
assigned_table_list
):
for
i
,
handle_table
in
enumerate
(
self
.
assigned_table_list
):
indices_start_position
=
offsets
[
batch_size
*
handle_table
]
with
record_function
(
"(tablewise) prepare indices and offsets"
):
if
(
not
self
.
include_last_offset
)
and
(
batch_size
*
(
handle_table
+
1
)
>=
indices
.
shape
[
0
]):
with
record_function
(
"part 1"
):
# till the end special case
indices_start_position
=
offsets
[
batch_size
*
handle_table
]
indices_end_position
=
indices
.
shape
[
0
]
if
(
not
self
.
include_last_offset
)
and
(
batch_size
*
(
handle_table
+
1
)
>=
indices
.
shape
[
0
]):
else
:
# till the end special case
indices_end_position
=
offsets
[
batch_size
*
(
handle_table
+
1
)]
indices_end_position
=
indices
.
shape
[
0
]
else
:
local_indices
=
indices
[
indices_start_position
:
indices_end_position
]
-
\
indices_end_position
=
offsets
[
batch_size
*
(
handle_table
+
1
)]
self
.
global_tables_offsets
[
handle_table
]
with
record_function
(
"part 2"
):
if
self
.
include_last_offset
:
# local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table]
local_offsets
=
offsets
[
batch_size
*
handle_table
:
batch_size
*
local_indices
=
indices
.
narrow
(
0
,
indices_start_position
,
indices_end_position
(
handle_table
+
1
)
+
1
]
-
offsets
[
batch_size
*
(
handle_table
)]
-
indices_start_position
).
sub
(
self
.
global_tables_offsets
[
handle_table
])
else
:
if
self
.
include_last_offset
:
local_offsets
=
offsets
[
batch_size
*
handle_table
:
batch_size
*
# local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
(
handle_table
+
1
)]
-
offsets
[
batch_size
*
(
handle_table
)]
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
local_per_sample_weights
=
None
batch_size
+
1
).
sub
(
offsets
[
batch_size
*
(
handle_table
)])
if
per_sample_weights
!=
None
:
else
:
local_per_sample_weights
=
per_sample_weights
[
indices_start_position
:
indices_end_position
]
# local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)]
local_output_list
.
append
(
self
.
freq_aware_embedding_bag_list
[
i
](
local_indices
,
local_offsets
,
local_offsets
=
offsets
.
narrow
(
0
,
batch_size
*
handle_table
,
local_per_sample_weights
))
batch_size
).
sub
(
offsets
[
batch_size
*
(
handle_table
)])
local_per_sample_weights
=
None
if
per_sample_weights
!=
None
:
local_per_sample_weights
=
per_sample_weights
[
indices_start_position
:
indices_end_position
]
with
record_function
(
"(tablewise) tablewise forward"
):
local_output_list
.
append
(
self
.
freq_aware_embedding_bag_list
[
i
](
local_indices
,
local_offsets
,
local_per_sample_weights
))
# get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
# get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
local_output
=
torch
.
cat
(
local_output_list
,
1
)
local_output
=
torch
.
cat
(
local_output_list
,
1
)
...
@@ -140,3 +277,21 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
...
@@ -140,3 +277,21 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
if
shape_hook
is
not
None
:
if
shape_hook
is
not
None
:
output_full
=
shape_hook
(
output_full
)
output_full
=
shape_hook
(
output_full
)
return
output_full
return
output_full
def
element_size
(
self
):
if
len
(
self
.
assigned_table_list
)
==
0
:
return
0
return
self
.
freq_aware_embedding_bag_list
[
0
].
cache_weight_mgr
.
weight
.
element_size
()
def
print_comm_stats_
(
self
):
cuda_to_cpu_elem_num
=
0
cpu_to_cuda_elem_num
=
0
for
freq_aware_embedding_bag
in
self
.
freq_aware_embedding_bag_list
:
cuda_to_cpu_elem_num
+=
freq_aware_embedding_bag
.
cache_weight_mgr
.
_cuda_to_cpu_numel
cpu_to_cuda_elem_num
+=
freq_aware_embedding_bag
.
cache_weight_mgr
.
_cpu_to_cuda_numel
print
(
f
"CUDA->CPU num:
{
cuda_to_cpu_elem_num
/
1e6
}
M elem"
)
print
(
f
"CPU->CUDA num:
{
cpu_to_cuda_elem_num
/
1e6
}
M elem"
)
tests/test_layers/test_cache_embedding.py
View file @
964123ae
...
@@ -13,7 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
...
@@ -13,7 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.tensor
import
ColoParameter
,
ProcessGroup
,
ShardSpec
,
ComputePattern
,
ComputeSpec
,
\
from
colossalai.tensor
import
ColoParameter
,
ProcessGroup
,
ShardSpec
,
ComputePattern
,
ComputeSpec
,
\
ColoTensor
,
ColoTensorSpec
ColoTensor
,
ColoTensorSpec
from
colossalai.nn.parallel.layers
import
CachedParamMgr
,
FreqAwareEmbeddingBag
,
ParallelFreqAwareEmbeddingBag
,
EvictionStrategy
,
\
from
colossalai.nn.parallel.layers
import
CachedParamMgr
,
FreqAwareEmbeddingBag
,
ParallelFreqAwareEmbeddingBag
,
EvictionStrategy
,
\
ParallelFreqAwareEmbeddingBagTablewise
,
TablewiseEmbeddingBagConfig
ParallelFreqAwareEmbeddingBagTablewise
,
TablewiseEmbeddingBagConfig
,
ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
from
typing
import
List
from
typing
import
List
NUM_EMBED
,
EMBED_DIM
=
10
,
8
NUM_EMBED
,
EMBED_DIM
=
10
,
8
...
@@ -209,9 +209,10 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
...
@@ -209,9 +209,10 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
# initialize weight
# initialize weight
# 3 feature tables. idx: 0~5, 6~10, 11~17
# 3 feature tables. idx: 0~5, 6~10, 11~17
weight_table1
=
torch
.
rand
(
6
,
5
)
weight_tables
=
torch
.
rand
(
18
,
5
)
weight_table2
=
torch
.
rand
(
5
,
5
)
weight_table1
=
weight_tables
[
0
:
6
]
weight_table3
=
torch
.
rand
(
7
,
5
)
weight_table2
=
weight_tables
[
6
:
11
]
weight_table3
=
weight_tables
[
11
:
18
]
embedding_bag_config_list
:
List
[
TablewiseEmbeddingBagConfig
]
=
[]
embedding_bag_config_list
:
List
[
TablewiseEmbeddingBagConfig
]
=
[]
embedding_bag_config_list
.
append
(
TablewiseEmbeddingBagConfig
(
embedding_bag_config_list
.
append
(
TablewiseEmbeddingBagConfig
(
num_embeddings
=
6
,
cuda_row_num
=
4
,
assigned_rank
=
0
,
initial_weight
=
weight_table1
.
clone
().
detach
().
cpu
()))
num_embeddings
=
6
,
cuda_row_num
=
4
,
assigned_rank
=
0
,
initial_weight
=
weight_table1
.
clone
().
detach
().
cpu
()))
...
@@ -219,14 +220,20 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
...
@@ -219,14 +220,20 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
num_embeddings
=
5
,
cuda_row_num
=
4
,
assigned_rank
=
0
,
initial_weight
=
weight_table2
.
clone
().
detach
().
cpu
()))
num_embeddings
=
5
,
cuda_row_num
=
4
,
assigned_rank
=
0
,
initial_weight
=
weight_table2
.
clone
().
detach
().
cpu
()))
embedding_bag_config_list
.
append
(
TablewiseEmbeddingBagConfig
(
embedding_bag_config_list
.
append
(
TablewiseEmbeddingBagConfig
(
num_embeddings
=
7
,
cuda_row_num
=
4
,
assigned_rank
=
1
,
initial_weight
=
weight_table3
.
clone
().
detach
().
cpu
()))
num_embeddings
=
7
,
cuda_row_num
=
4
,
assigned_rank
=
1
,
initial_weight
=
weight_table3
.
clone
().
detach
().
cpu
()))
if
rank
==
0
:
_weight
=
torch
.
cat
([
weight_table1
,
weight_table2
],
0
)
else
:
_weight
=
weight_table3
model
=
ParallelFreqAwareEmbeddingBagTablewise
(
model
=
ParallelFreqAwareEmbeddingBagTablewise
(
embedding_bag_config_list
,
embedding_bag_config_list
,
embedding_dim
=
5
,
embedding_dim
=
5
,
_weight
=
_weight
,
include_last_offset
=
True
,
cuda_row_num
=
8
,
buffer_size
=
0
,
evict_strategy
=
EvictionStrategy
.
LFU
,
evict_strategy
=
EvictionStrategy
.
LFU
,
include_last_offset
=
True
)
)
#
demo
explain
:
# explain
'''
'''
batch feature 1 feature 2 feature 3
batch feature 1 feature 2 feature 3
input0 [1,2,3] [6,7] []
input0 [1,2,3] [6,7] []
...
@@ -244,28 +251,27 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
...
@@ -244,28 +251,27 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
fake_grad
=
rand_grad
[
0
:
2
]
fake_grad
=
rand_grad
[
0
:
2
]
else
:
else
:
fake_grad
=
rand_grad
[
2
:]
fake_grad
=
rand_grad
[
2
:]
res
.
backward
(
fake_grad
)
res
.
backward
(
fake_grad
)
optimizer
.
step
()
optimizer
.
step
()
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
# check correctness
on weight_table2
# check correctness
if
rank
==
0
:
if
rank
==
0
:
ref_model
=
torch
.
nn
.
EmbeddingBag
.
from_pretrained
(
weight_table
2
.
detach
().
clone
(),
ref_model
=
torch
.
nn
.
EmbeddingBag
.
from_pretrained
(
weight_table
s
.
detach
().
clone
(),
include_last_offset
=
True
,
include_last_offset
=
True
,
freeze
=
False
).
to
(
device
)
freeze
=
False
).
to
(
device
)
ref_optimizer
=
torch
.
optim
.
SGD
(
ref_model
.
parameters
(),
lr
=
1e-2
)
ref_optimizer
=
torch
.
optim
.
SGD
(
ref_model
.
parameters
(),
lr
=
1e-2
)
ref_grad
=
rand_grad
[:,
5
:
10
]
ref_fake_grad
=
torch
.
cat
(
rand_grad
.
split
(
5
,
1
),
0
)
ref_res
=
ref_model
(
torch
.
tensor
([
0
,
1
,
3
,
0
,
2
],
device
=
device
),
torch
.
tensor
([
0
,
2
,
3
,
5
],
device
=
device
))
ref_res
=
ref_model
(
torch
.
tensor
([
1
,
2
,
3
,
1
,
5
,
6
,
7
,
9
,
6
,
8
,
13
,
15
,
11
],
device
=
device
),
ref_res
.
backward
(
ref_grad
)
torch
.
tensor
([
0
,
3
,
3
,
5
,
7
,
8
,
10
,
10
,
12
,
13
],
device
=
device
))
ref_res
.
backward
(
ref_fake_grad
)
ref_optimizer
.
step
()
ref_optimizer
.
step
()
ref_optimizer
.
zero_grad
()
ref_optimizer
.
zero_grad
()
model
.
freq_aware_embedding_bag_list
[
1
].
cache_weight_mgr
.
flush
()
# update cpu weight
model
.
cache_weight_mgr
.
flush
()
recover_weight
=
model
.
freq_aware_embedding_bag_list
[
1
].
cache_weight_mgr
.
weight
recover_weight
=
model
.
cache_weight_mgr
.
weight
.
to
(
device
)
assert
torch
.
allclose
(
recover_weight
,
ref_model
.
weight
.
detach
().
cpu
()
ref_weight
=
ref_model
.
weight
.
detach
()[:
11
]
),
f
"
{
recover_weight
-
ref_model
.
weight
.
detach
().
cpu
()
}
"
assert
torch
.
allclose
(
recover_weight
,
ref_weight
),
f
"
{
recover_weight
-
ref_weight
}
"
def
run_parallel_freq_aware_embed_columnwise
(
rank
,
world_size
):
def
run_parallel_freq_aware_embed_columnwise
(
rank
,
world_size
):
device
=
torch
.
device
(
'cuda'
,
torch
.
cuda
.
current_device
())
device
=
torch
.
device
(
'cuda'
,
torch
.
cuda
.
current_device
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment