OpenDAS / ColossalAI / Commits

Commit 21962e15 (unverified), authored Oct 13, 2022 by Jiarui Fang, committed by GitHub on Oct 13, 2022. Parent: 0e52f3d3

[embedding] rename FreqAwareEmbedding -> CachedEmbedding (#1699)

The commit mechanically renames the frequency-aware embedding classes, modules, files, and attributes to the Cached* naming across the package and its tests; apart from one timer change in CachedParamMgr, behavior is unchanged.

Showing 8 changed files with 77 additions and 76 deletions:
colossalai/nn/parallel/layers/__init__.py (+5 −5)
colossalai/nn/parallel/layers/cache_embedding/__init__.py (+7 −7)
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py (+2 −1)
colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py (+5 −5)
colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py (+4 −4)
colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py (+4 −4)
colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py (+31 −31)
tests/test_layers/test_cache_embedding.py (+19 −19)
colossalai/nn/parallel/layers/__init__.py

```diff
@@ -3,12 +3,12 @@ from .linear import ColoLinear
 from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
-from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from .cache_embedding import CachedEmbeddingBag, ParallelCachedEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
+    ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewiseSpiltCache
 
 __all__ = [
     'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
-    'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
-    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
-    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
+    'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr',
+    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelCachedEmbeddingBagTablewiseSpiltCache'
 ]
```
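The rename is purely mechanical: class names, module paths, and attribute names change, but no call signature does. A minimal migration sketch for downstream code (hypothetical user code, assuming a post-#1699 build of colossalai):

```python
# Before this commit (hypothetical downstream code):
#   from colossalai.nn.parallel.layers import FreqAwareEmbeddingBag

# After this commit, the same layer is imported under its new name:
from colossalai.nn.parallel.layers import CachedEmbeddingBag

# Constructor arguments are unchanged, so a one-line alias keeps older
# call sites working while they are migrated:
FreqAwareEmbeddingBag = CachedEmbeddingBag
```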
colossalai/nn/parallel/layers/cache_embedding/__init__.py

```diff
 from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
-from .freq_aware_embedding import FreqAwareEmbeddingBag
-from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
+from .parallel_cached_embedding import ParallelCachedEmbeddingBag
 from .embedding_config import TablewiseEmbeddingBagConfig
-from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise
-from .parallel_freq_aware_embedding_tablewise_split_cache import ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise
+from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache
 
 __all__ = [
-    'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
-    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
-    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
+    'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag',
+    'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelCachedEmbeddingBagTablewiseSpiltCache'
 ]
```
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py

```diff
@@ -352,7 +352,8 @@ class CachedParamMgr(torch.nn.Module):
         # move sure the cuda rows will not be evicted!
-        with record_function("(cache) prepare_rows_on_cuda"):
-            self._prepare_rows_on_cuda(comm_cpu_row_idxs)
+        with self.timer("prepare_rows_on_cuda") as timer:
+            self._prepare_rows_on_cuda(comm_cpu_row_idxs)
 
         self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
```
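Beyond the renames, this hunk carries the commit's one behavioral tweak: the profiler's `record_function` scope becomes a `self.timer(...)` context manager on `CachedParamMgr`. The helper itself is not shown in this diff, so the following is only a sketch of the kind of context-manager timer the call pattern implies, not ColossalAI's actual implementation:

```python
import time
from contextlib import contextmanager

@contextmanager
def timer(name: str):
    # Hypothetical stand-in: time the wrapped block, report wall-clock seconds.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{name}: {time.perf_counter() - start:.6f}s")

# Mirrors the diff's `with self.timer("prepare_rows_on_cuda") as timer:` usage.
with timer("prepare_rows_on_cuda"):
    pass  # the real code calls self._prepare_rows_on_cuda(comm_cpu_row_idxs) here
```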
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py → colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py

```diff
@@ -7,10 +7,10 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy
 from torch.nn.parameter import Parameter
 
 
-class FreqAwareEmbeddingBag(BaseEmbeddingBag):
-    """FreqAwareEmbeddingBag
+class CachedEmbeddingBag(BaseEmbeddingBag):
+    """CachedEmbeddingBag
 
-    Frequency Aware Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space.
+    Cached Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space.
     It can leverage the id's frequency statistics of the target dataset, by passing a frequency list to param `ids_freq_mapping`.
     You can also apply a navie LFU cache eviction strategy by setting `evict_strategy` as EvictionStrategy.LFU.
@@ -54,8 +54,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                  buffer_size: int = 0,
                  pin_weight: bool = False,
                  evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
-        super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
-                                                    scale_grad_by_freq, sparse, mode, include_last_offset)
+        super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
+                                                 scale_grad_by_freq, sparse, mode, include_last_offset)
         assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0"
         self.evict_strategy = evict_strategy
```
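A short usage sketch of the renamed layer. The keyword arguments mirror those exercised in `tests/test_layers/test_cache_embedding.py` further down; the concrete sizes are illustrative, and a CUDA device plus a post-#1699 colossalai build are assumed:

```python
import torch
from colossalai.nn.parallel.layers import CachedEmbeddingBag, EvictionStrategy

# Keep 60% of a 1000-row table resident on the GPU; CachedParamMgr pages the
# remaining rows in from CPU memory on demand, evicting by LFU.
bag = CachedEmbeddingBag(1000,
                         64,
                         mode='mean',
                         include_last_offset=True,
                         cache_ratio=0.6,
                         ids_freq_mapping=None,    # optionally pass dataset id frequencies
                         evict_strategy=EvictionStrategy.LFU).to('cuda:0')

indices = torch.tensor([1, 7, 42, 7], device='cuda:0')
offsets = torch.tensor([0, 2, 4], device='cuda:0')    # include_last_offset=True
out = bag(indices, offsets)    # shape (2, 64): two pooled bags of embeddings
```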
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py → colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py

```diff
@@ -2,7 +2,7 @@ import torch
 import torch.nn.functional as F
 from typing import List, Optional, Iterator, Tuple
 
-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
 from colossalai.nn._ops._utils import dual_all_to_all
 from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
@@ -28,7 +28,7 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
     return offset, offset + size_list[rank], False
 
 
-class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
+class ParallelCachedEmbeddingBag(CachedEmbeddingBag):
 
     def __init__(self,
                  num_embeddings,
@@ -56,7 +56,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                                               embedding_dim, self.rank, self.world_size)
         self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
 
-        super(ParallelFreqAwareEmbeddingBag,
+        super(ParallelCachedEmbeddingBag,
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                              warmup_ratio, buffer_size, pin_weight, evict_strategy)
@@ -115,7 +115,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                         ids_freq_mapping: Optional[List[int]] = None,
                         warmup_ratio: float = 0.7,
                         buffer_size: int = 0,
-                        ) -> 'ParallelFreqAwareEmbeddingBag':
+                        ) -> 'ParallelCachedEmbeddingBag':
         rows, cols = embedding.shape
         embedding_bag = cls(rows, cols,
```
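`from_pretrained` keeps its signature under the new name. Condensed from `run_parallel_freq_aware_embed_columnwise` in the test file at the bottom of this page; it assumes a launched multi-process run in which `coloweight` already wraps the full `(num_embeddings, embedding_dim)` weight as a ColoParameter:

```python
from colossalai.tensor import ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
from colossalai.nn.parallel.layers import ParallelCachedEmbeddingBag

# Shard the embedding weight column-wise across the tensor-parallel group.
coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]),
                           ComputeSpec(ComputePattern.TP1D))

# Build the cached, column-sharded bag from the existing weight.
model = ParallelCachedEmbeddingBag.from_pretrained(coloweight,
                                                   include_last_offset=True,
                                                   freeze=False)
```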
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py → colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py

```diff
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 
-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
 from .cache_mgr import EvictionStrategy
 from .embedding_config import TablewiseEmbeddingBagConfig
 from colossalai.tensor import ProcessGroup
@@ -12,9 +12,9 @@ from typing import List
 import time
 
 
-class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
+class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
     """
-    all tables assigned to this class instance are managed by a single FreqAwareEmbeddingBag.
+    all tables assigned to this class instance are managed by a single CachedEmbeddingBag.
     Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
     """
@@ -62,7 +62,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         self.cache_ratio = cache_ratio
         # table-associate cache
         cuda_row_num = int(cache_ratio * self.num_embeddings)
-        super(ParallelFreqAwareEmbeddingBagTablewise,
+        super(ParallelCachedEmbeddingBagTablewise,
               self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                              warmup_ratio, buffer_size, pin_weight, evict_strategy)
```
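A construction sketch for the tablewise variant, following the call pattern of `run_parallel_freq_aware_embed_tablewise` in the test file below. The `TablewiseEmbeddingBagConfig` field names come from this commit's diffs, but passing them as constructor keywords is an assumption, and a launched 2-rank run is presumed:

```python
from colossalai.nn.parallel.layers import (ParallelCachedEmbeddingBagTablewise,
                                           TablewiseEmbeddingBagConfig)

# One config per table; each table is pinned to one rank. Per the docstring
# above, cuda_row_num/buffer_size/initial_weight in the configs are ignored
# by this (non-split-cache) variant.
embedding_bag_config_list = [
    TablewiseEmbeddingBagConfig(num_embeddings=100, assigned_rank=0),
    TablewiseEmbeddingBagConfig(num_embeddings=50, assigned_rank=1),
]
model = ParallelCachedEmbeddingBagTablewise(embedding_bag_config_list,
                                            embedding_dim=5)
```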
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise_split_cache.py → colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py

```diff
@@ -3,7 +3,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.profiler import record_function
 
-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
 from colossalai.tensor import ProcessGroup
 from colossalai.nn._ops._utils import dual_all_to_all_tablewise
@@ -14,9 +14,9 @@ from typing import List
 import abc
 
 
-class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
+class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
     """
-    every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
+    every table assigned to this class instance is managed by a CachedEmbeddingBag.
     """
 
     def __init__(self,
@@ -34,7 +34,7 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
                  warmup_ratio=0.7,
                  pin_weight=False,
                  evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
-        super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__()
+        super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__()
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
         self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
@@ -49,31 +49,31 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
         self.include_last_offset = include_last_offset
         self.pg = ProcessGroup(tp_degree=self.world_size)
 
-        # prepare FreqAwareEmbeddingBag list
-        self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList()
+        # prepare CachedEmbeddingBag list
+        self.cached_embedding_bag_list: nn.ModuleList = nn.ModuleList()
         for config in embedding_bag_config_list:
             if config.assigned_rank != self.rank:
                 continue
-            self.freq_aware_embedding_bag_list.append(
-                FreqAwareEmbeddingBag(
-                    num_embeddings=config.num_embeddings,
-                    embedding_dim=embedding_dim,
-                    padding_idx=padding_idx,
-                    max_norm=max_norm,
-                    norm_type=norm_type,
-                    scale_grad_by_freq=scale_grad_by_freq,
-                    sparse=sparse,
-                    _weight=config.initial_weight,
-                    mode=mode,
-                    include_last_offset=include_last_offset,
-                    dtype=dtype,
-                    device=device,
-                    cuda_row_num=config.cuda_row_num,
-                    ids_freq_mapping=config.ids_freq_mapping,
-                    warmup_ratio=warmup_ratio,
-                    buffer_size=config.buffer_size,
-                    pin_weight=pin_weight,
-                    evict_strategy=evict_strategy))
+            self.cached_embedding_bag_list.append(
+                CachedEmbeddingBag(
+                    num_embeddings=config.num_embeddings,
+                    embedding_dim=embedding_dim,
+                    padding_idx=padding_idx,
+                    max_norm=max_norm,
+                    norm_type=norm_type,
+                    scale_grad_by_freq=scale_grad_by_freq,
+                    sparse=sparse,
+                    _weight=config.initial_weight,
+                    mode=mode,
+                    include_last_offset=include_last_offset,
+                    dtype=dtype,
+                    device=device,
+                    cuda_row_num=config.cuda_row_num,
+                    ids_freq_mapping=config.ids_freq_mapping,
+                    warmup_ratio=warmup_ratio,
+                    buffer_size=config.buffer_size,
+                    pin_weight=pin_weight,
+                    evict_strategy=evict_strategy))
 
         # prepare list shape for all_to_all output
         self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
@@ -109,8 +109,8 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
             if per_sample_weights != None:
                 local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
             with record_function("(tablewise) tablewise forward"):
-                local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
-                                                                               local_per_sample_weights))
+                local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets,
+                                                                           local_per_sample_weights))
 
         # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
         local_output = torch.cat(local_output_list, 1)
@@ -126,13 +126,13 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
     def element_size(self):
         if len(self.assigned_table_list) == 0:
             return 0
-        return self.freq_aware_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
+        return self.cached_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
 
     def print_comm_stats_(self):
         cuda_to_cpu_elem_num = 0
         cpu_to_cuda_elem_num = 0
-        for freq_aware_embedding_bag in self.freq_aware_embedding_bag_list:
-            cuda_to_cpu_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
-            cpu_to_cuda_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
+        for cached_embedding_bag in self.cached_embedding_bag_list:
+            cuda_to_cpu_elem_num += cached_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
+            cpu_to_cuda_elem_num += cached_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
         print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem")
         print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem")
```
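In the forward path above, each locally assigned table runs through its own bag and the results are concatenated along the feature dimension before the `dual_all_to_all_tablewise` exchange. The same gather-and-concat pattern in isolation, using plain `torch.nn.EmbeddingBag` stand-ins rather than the class's actual `CachedEmbeddingBag` modules:

```python
import torch
import torch.nn as nn

# Stand-ins for one rank's assigned tables; the real class keeps
# CachedEmbeddingBag modules in self.cached_embedding_bag_list.
bags = nn.ModuleList([nn.EmbeddingBag(100, 8, mode='mean'),
                      nn.EmbeddingBag(50, 8, mode='mean')])

# Per-table index slices for a batch of two samples.
per_table_indices = [torch.tensor([3, 9, 27, 1]), torch.tensor([0, 5, 5, 7])]
offsets = torch.tensor([0, 2])

local_output_list = [bag(idx, offsets) for bag, idx in zip(bags, per_table_indices)]
# shape (batch_size, len(assigned_tables) * embedding_dim) = (2, 16),
# matching the comment in the diff above.
local_output = torch.cat(local_output_list, 1)
```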
tests/test_layers/test_cache_embedding.py

```diff
@@ -12,8 +12,8 @@ from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
+from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \
+    ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
 from typing import List
 
 NUM_EMBED, EMBED_DIM = 10, 8
@@ -106,13 +106,13 @@ def test_reorder_with_freq():
 def test_freq_aware_embed(use_LFU: bool):
     device = torch.device('cuda', 0)
     evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
-    model = FreqAwareEmbeddingBag(NUM_EMBED,
-                                  EMBED_DIM,
-                                  mode='mean',
-                                  include_last_offset=True,
-                                  cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
-                                  ids_freq_mapping=None,
-                                  evict_strategy=evict_strategy).to(device)
+    model = CachedEmbeddingBag(NUM_EMBED,
+                               EMBED_DIM,
+                               mode='mean',
+                               include_last_offset=True,
+                               cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
+                               ids_freq_mapping=None,
+                               evict_strategy=evict_strategy).to(device)
     assert model.weight.shape[0] == NUM_EMBED
     ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
@@ -151,14 +151,14 @@ def test_freq_aware_embed(use_LFU: bool):
 @pytest.mark.parametrize('init_freq', [True, False])
 def test_lfu_strategy(init_freq: bool):
     # minimal test to check behavior
-    Bag = FreqAwareEmbeddingBag(5,
-                                5,
-                                cache_ratio=3 / 5,
-                                buffer_size=0,
-                                pin_weight=True,
-                                ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
-                                warmup_ratio=1.0,
-                                evict_strategy=EvictionStrategy.LFU)
+    Bag = CachedEmbeddingBag(5,
+                             5,
+                             cache_ratio=3 / 5,
+                             buffer_size=0,
+                             pin_weight=True,
+                             ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
+                             warmup_ratio=1.0,
+                             evict_strategy=EvictionStrategy.LFU)
 
     # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
     offsets = torch.tensor([0], device="cuda:0")
@@ -233,7 +233,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
         _weight = torch.cat([weight_table1, weight_table2], 0)
     else:
         _weight = weight_table3
-    model = ParallelFreqAwareEmbeddingBagTablewise(
+    model = ParallelCachedEmbeddingBagTablewise(
         embedding_bag_config_list,
         embedding_dim=5,
         _weight=_weight,
@@ -300,7 +300,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
     coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
 
-    model = ParallelFreqAwareEmbeddingBag.from_pretrained(
+    model = ParallelCachedEmbeddingBag.from_pretrained(
         coloweight,
         include_last_offset=True,
         freeze=False,
```