OpenDAS / ColossalAI · Commit 5156d5b4
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "4f68b3f10ce55a3563f943f8163b460d8c9fbb19"
Unverified commit 5156d5b4, authored Sep 01, 2022 by CsRic; committed by GitHub, Sep 01, 2022
[embedding] add tablewise sharding for FAW (#1526)
parent f1e18362
Showing 6 changed files with 273 additions and 13 deletions (+273 −13)
colossalai/nn/parallel/layers/__init__.py                                                  +3 −2
colossalai/nn/parallel/layers/cache_embedding/__init__.py                                  +2 −2
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py                      +1 −1
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py             +0 −2
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py   +192 −0
tests/test_layers/test_cache_embedding.py                                                  +75 −6
colossalai/nn/parallel/layers/__init__.py

@@ -3,10 +3,11 @@ from .linear import ColoLinear
 from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
-from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy
+from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
+    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig

 __all__ = [
     'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
     'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
-    'LimitBuffIndexCopyer', 'EvictionStrategy'
+    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig'
 ]
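With this change the tablewise classes are re-exported at the package level next to the existing FAW layers, so downstream code can import them directly. A minimal import sketch (assumed usage, not part of the commit):

    from colossalai.nn.parallel.layers import ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig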
colossalai/nn/parallel/layers/cache_embedding/__init__.py

@@ -2,8 +2,8 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
+from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig

 __all__ = [
     'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
-    'EvictionStrategy'
+    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig'
 ]
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py

@@ -99,7 +99,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None):
        with torch.no_grad():
            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)

        embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
                                     self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                     per_sample_weights, self.include_last_offset, self.padding_idx)
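The forward pass above hands every EmbeddingBag option positionally to torch.nn.functional.embedding_bag. A small self-contained sketch of that positional order (illustrative shapes and values, not from the commit):

    import torch
    import torch.nn.functional as F

    weight = torch.randn(10, 4)                       # 10 rows, embedding dim 4
    indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
    offsets = torch.tensor([0, 4])                    # two bags of four indices each
    out = F.embedding_bag(indices, weight, offsets,
                          None, 2., False, 'mean', False,   # max_norm, norm_type, scale_grad_by_freq, mode, sparse
                          None, False, None)                # per_sample_weights, include_last_offset, padding_idx
    print(out.shape)                                  # torch.Size([2, 4])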
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py

@@ -79,10 +79,8 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
        output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                       per_sample_weights, self.include_last_offset, self.padding_idx)

        if shape_hook is not None:
            output_shard = shape_hook(output_shard)

        output_full = dual_all_to_all(output_shard,
                                      self.weight.get_process_group(),
                                      scatter_dim=scatter_dim,
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py  (new file, mode 100644)
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.nn as nn
from typing import List, Optional, Iterator, Tuple
import abc

from .freq_aware_embedding import FreqAwareEmbeddingBag
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
from .cache_mgr import CachedParamMgr, EvictionStrategy


class TablewiseEmbeddingBagConfig:
    '''
    example:
    def prepare_tablewise_config(args, cache_ratio, ...):
        embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
        ...
        return embedding_bag_config_list
    '''

    def __init__(self,
                 num_embeddings: int,
                 cuda_row_num: int,
                 assigned_rank: int = 0,
                 buffer_size=50_000,
                 ids_freq_mapping=None,
                 initial_weight: torch.tensor = None,
                 name: str = ""):
        self.num_embeddings = num_embeddings
        self.cuda_row_num = cuda_row_num
        self.assigned_rank = assigned_rank
        self.buffer_size = buffer_size
        self.ids_freq_mapping = ids_freq_mapping
        self.initial_weight = initial_weight
        self.name = name


def _all_to_all_for_tablewise(x: torch.Tensor,
                              pg: ProcessGroup,
                              scatter_strides: List[int],
                              gather_strides: List[int],
                              forward=True) -> torch.Tensor:
    world_size = pg.tp_world_size()
    rank = pg.tp_local_rank()
    if world_size == 1:
        return x
    assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend"
    if forward:
        scatter_list = list(x.split(scatter_strides, 0))
        gather_list = [torch.empty(scatter_strides[rank], gather_strides[i], dtype=x.dtype, device=x.device)
                       for i in range(world_size)]
        torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
        return torch.cat(gather_list, 1).contiguous()
    else:
        # split on dim 1, lose contiguity
        scatter_list = [each.contiguous() for each in x.split(scatter_strides, 1)]
        gather_list = [torch.empty(gather_strides[i], scatter_strides[rank], dtype=x.dtype, device=x.device)
                       for i in range(world_size)]
        torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
        return torch.cat(gather_list, 0).contiguous()


class _DualAllToAllForTablewise(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, pg, scatter_strides, gather_strides):
        ctx.pg = pg
        ctx.scatter_strides = scatter_strides
        ctx.gather_strides = gather_strides
        return _all_to_all_for_tablewise(x, pg, scatter_strides, gather_strides, forward=True)

    @staticmethod
    def backward(ctx, grad):
        return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides,
                                         forward=False), None, None, None


def _dual_all_to_all(x, pg, scatter_strides, gather_strides):
    return _DualAllToAllForTablewise.apply(x, pg, scatter_strides, gather_strides)
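# Worked example of the shape bookkeeping above (assumed values, 2 ranks):
# rank 0 owns tables totalling embedding dim 10, rank 1 owns dim 5, global batch size 3, so
#   gather_strides  = embedding_dim_per_rank = [10, 5]
#   scatter_strides = [batch_size // world_size + int(i < remains) for i in range(world_size)] = [2, 1]
# Before the exchange, rank r holds its local output of shape (batch_size, embedding_dim_per_rank[r]);
# after all_to_all + cat(dim=1), rank r holds (scatter_strides[r], sum(embedding_dim_per_rank)),
# i.e. (2, 15) on rank 0 and (1, 15) on rank 1: its batch shard with the full concatenated embedding dim.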
class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
    '''
    every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
    '''

    def __init__(self,
                 embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
                 embedding_dim: int,
                 padding_idx=None,
                 max_norm=None,
                 norm_type=2.,
                 scale_grad_by_freq=False,
                 sparse=False,
                 mode='mean',
                 include_last_offset=False,
                 dtype=None,
                 device=None,
                 warmup_ratio=0.7,
                 pin_weight=False,
                 evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
        super(ParallelFreqAwareEmbeddingBagTablewise, self).__init__()
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.global_table_assign_list = [config.assigned_rank for config in embedding_bag_config_list]
        self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
        self.global_tables_num = len(embedding_bag_config_list)
        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0)
        self.assigned_table_list: List[int] = []
        for i, rank in enumerate(self.global_table_assign_list):
            if rank == self.rank:
                self.assigned_table_list.append(i)
        self.include_last_offset = include_last_offset
        self.pg = ProcessGroup(tp_degree=self.world_size)

        # prepare FreqAwareEmbeddingBag list
        self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList()
        for config in embedding_bag_config_list:
            if config.assigned_rank != self.rank:
                continue
            self.freq_aware_embedding_bag_list.append(
                FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings,
                                      embedding_dim=embedding_dim,
                                      padding_idx=padding_idx,
                                      max_norm=max_norm,
                                      norm_type=norm_type,
                                      scale_grad_by_freq=scale_grad_by_freq,
                                      sparse=sparse,
                                      _weight=config.initial_weight,
                                      mode=mode,
                                      include_last_offset=include_last_offset,
                                      dtype=dtype,
                                      device=device,
                                      cuda_row_num=config.cuda_row_num,
                                      ids_freq_mapping=config.ids_freq_mapping,
                                      warmup_ratio=warmup_ratio,
                                      buffer_size=config.buffer_size,
                                      pin_weight=pin_weight,
                                      evict_strategy=evict_strategy))

        # prepare list shape for all_to_all output
        self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
        for rank in self.global_table_assign_list:
            self.embedding_dim_per_rank[rank] += embedding_dim

        #print("global_table_assign_list {}".format(self.global_table_assign_list))
        #print("global_table_num_embeddings_list {}".format(self.global_table_num_embeddings_list))
        #print("global_tables_offsets {}".format(self.global_tables_offsets))

    def forward(self,
                indices: torch.Tensor,
                offsets: torch.Tensor = None,
                per_sample_weights=None,
                shape_hook=None):
        # determine indices to handle
        batch_size = (offsets.shape[0]) // self.global_tables_num
        local_output_list = []
        for i, handle_table in enumerate(self.assigned_table_list):
            indices_start_position = offsets[batch_size * handle_table]
            if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
                # till the end special case
                indices_end_position = indices.shape[0]
            else:
                indices_end_position = offsets[batch_size * (handle_table + 1)]
            local_indices = indices[indices_start_position:indices_end_position] - \
                self.global_tables_offsets[handle_table]
            if self.include_last_offset:
                local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - \
                    offsets[batch_size * (handle_table)]
            else:
                local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - \
                    offsets[batch_size * (handle_table)]
            local_per_sample_weights = None
            if per_sample_weights != None:
                local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
            local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
                                                                           local_per_sample_weights))

        # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
        local_output = torch.cat(local_output_list, 1)
        # then concatenate those local_output on the second demension.
        # use all_to_all
        remains = batch_size % self.world_size
        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
        output_full = _dual_all_to_all(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
        if shape_hook is not None:
            output_full = shape_hook(output_full)
        return output_full
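A minimal construction sketch for the new module, mirroring the new test below (the ranks, table sizes and keyword values are illustrative, and torch.distributed is assumed to be initialized with 2 ranks):

    configs = [
        TablewiseEmbeddingBagConfig(num_embeddings=6, cuda_row_num=4, assigned_rank=0),
        TablewiseEmbeddingBagConfig(num_embeddings=5, cuda_row_num=4, assigned_rank=0),
        TablewiseEmbeddingBagConfig(num_embeddings=7, cuda_row_num=4, assigned_rank=1),
    ]
    model = ParallelFreqAwareEmbeddingBagTablewise(configs,
                                                   embedding_dim=5,
                                                   include_last_offset=True,
                                                   evict_strategy=EvictionStrategy.LFU)
    # forward expects a flat `indices` tensor plus an `offsets` tensor laid out table-major:
    # all bags of table 0 for the whole batch, then table 1, then table 2.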
tests/test_layers/test_cache_embedding.py

@@ -12,7 +12,9 @@ from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy
+from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
+    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
+from typing import List

 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
@@ -200,7 +202,72 @@ def gather_tensor(tensor, rank, world_size):
     return gather_list

-def run_parallel_freq_aware_embed(rank, world_size):
+def run_parallel_freq_aware_embed_tablewise(rank, world_size):
+    if world_size != 2:
+        return
+    device = torch.device('cuda', torch.cuda.current_device())
+    # initialize weight
+    # 3 feature tables. idx: 0~5, 6~10, 11~17
+    weight_table1 = torch.rand(6, 5)
+    weight_table2 = torch.rand(5, 5)
+    weight_table3 = torch.rand(7, 5)
+    embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=6, cuda_row_num=4, assigned_rank=0,
+                                    initial_weight=weight_table1.clone().detach().cpu()))
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=5, cuda_row_num=4, assigned_rank=0,
+                                    initial_weight=weight_table2.clone().detach().cpu()))
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=7, cuda_row_num=4, assigned_rank=1,
+                                    initial_weight=weight_table3.clone().detach().cpu()))
+    model = ParallelFreqAwareEmbeddingBagTablewise(embedding_bag_config_list,
+                                                   embedding_dim=5,
+                                                   evict_strategy=EvictionStrategy.LFU,
+                                                   include_last_offset=True)
+    # demo explain:
+    '''
+    batch       feature 1       feature 2       feature 3
+    input0      [1,2,3]         [6,7]           []
+    input1      []              [9]             [13,15]
+    input2      [1,5]           [6,8]           [11]
+                  ↑               ↑               ↑
+                rank 0          rank 0          rank 1
+    in KJT format
+    '''
+    res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
+                torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
+    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
+    rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
+    if rank == 0:
+        fake_grad = rand_grad[0:2]
+    else:
+        fake_grad = rand_grad[2:]
+    res.backward(fake_grad)
+    optimizer.step()
+    optimizer.zero_grad()
+    # check correctness on weight_table2
+    if rank == 0:
+        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_table2.detach().clone(),
+                                                          include_last_offset=True,
+                                                          freeze=False).to(device)
+        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
+        ref_grad = rand_grad[:, 5:10]
+        ref_res = ref_model(torch.tensor([0, 1, 3, 0, 2], device=device), torch.tensor([0, 2, 3, 5], device=device))
+        ref_res.backward(ref_grad)
+        ref_optimizer.step()
+        ref_optimizer.zero_grad()
+        model.freq_aware_embedding_bag_list[1].cache_weight_mgr.flush()    # update cpu weight
+        recover_weight = model.freq_aware_embedding_bag_list[1].cache_weight_mgr.weight
+        assert torch.allclose(recover_weight, ref_model.weight.detach().cpu()
+                             ), f"{recover_weight - ref_model.weight.detach().cpu()}"
+
+def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     device = torch.device('cuda', torch.cuda.current_device())
     num_embed = 100
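For reference, how the demo batch above flattens into the single (indices, offsets) pair the test passes to the model (a layout sketch derived from the table in the docstring, with include_last_offset=True):

    # feature 1 bags: [1,2,3], [],      [1,5]  -> start offsets 0, 3, 3
    # feature 2 bags: [6,7],   [9],     [6,8]  -> start offsets 5, 7, 8
    # feature 3 bags: [],      [13,15], [11]   -> start offsets 10, 10, 12, plus a final 13
    indices = [1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11]
    offsets = [0, 3, 3, 5, 7, 8, 10, 10, 12, 13]
    assert len(offsets) == 3 * 3 + 1 and offsets[-1] == len(indices)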
@@ -219,7 +286,8 @@ def run_parallel_freq_aware_embed(rank, world_size):
     model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                           include_last_offset=True,
                                                           freeze=False,
-                                                          cuda_row_num=batch_size * 2)
+                                                          cuda_row_num=batch_size * 2,
+                                                          )

     assert model.cache_weight_mgr.weight.device.type == 'cpu'
     assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
@@ -269,7 +337,8 @@ def run_parallel_freq_aware_embed(rank, world_size):
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_parallel_freq_aware_embed(rank, world_size)
+    # run_parallel_freq_aware_embed_columnwise(rank, world_size)
+    run_parallel_freq_aware_embed_tablewise(rank, world_size)

 @pytest.mark.dist
@@ -281,6 +350,6 @@ def test_parallel_freq_aware_embed(world_size):
 if __name__ == '__main__':
-    test_freq_aware_embed(True)
-    # test_parallel_freq_aware_embed(2)
+    # test_freq_aware_embed(True)
+    test_parallel_freq_aware_embed(2)
     # test_lfu_strategy(False)