OpenDAS / ColossalAI · Commits

Commit 64169f3e (unverified), authored Sep 06, 2022 by Jiarui Fang, committed by GitHub on Sep 06, 2022

[embedding] polish parallel embedding tablewise (#1545)

parent 46c6cc79

Showing 6 changed files with 232 additions and 204 deletions (+232 / -204):
  colossalai/nn/parallel/layers/cache_embedding/__init__.py                                             +6    -2
  colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py                                             +8    -8
  colossalai/nn/parallel/layers/cache_embedding/embedding_config.py                                      +27   -0
  colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py               +24   -176
  colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise_split_cache.py   +138  -0
  tests/test_layers/test_cache_embedding.py                                                              +29   -18
colossalai/nn/parallel/layers/cache_embedding/__init__.py

@@ -2,8 +2,12 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
-from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from .embedding_config import TablewiseEmbeddingBagConfig
+from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise
+from .parallel_freq_aware_embedding_tablewise_split_cache import ParallelFreqAwareEmbeddingBagTablewiseSpiltCache

 __all__ = [
-    'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', 'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
+    'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
+    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
 ]
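The package-level exports are unchanged by the split, so downstream code keeps importing from the same place; the test file at the bottom of this page does exactly this. A minimal sketch (assuming colossalai.nn.parallel.layers re-exports this subpackage's __all__, as the test's imports indicate):

from colossalai.nn.parallel.layers import (
    TablewiseEmbeddingBagConfig,
    ParallelFreqAwareEmbeddingBagTablewise,
    ParallelFreqAwareEmbeddingBagTablewiseSpiltCache,
)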
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py

@@ -293,7 +293,7 @@ class CachedParamMgr(torch.nn.Module):
         Returns:
             torch.Tensor: indices on the cuda_cached_weight.
         """
-        with record_function("(zhg) get unique indices"):
+        with record_function("(pre-id) get unique indices"):
             ids = ids.to(self._cache_dev)
             cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
@@ -303,7 +303,7 @@ class CachedParamMgr(torch.nn.Module):
                 f"Please increase cuda_row_num or decrease the training batch size."
         self.evict_backlist = cpu_row_idxs
-        with record_function("(zhg) get cpu row idxs"):
+        with record_function("(pre-id) get cpu row idxs"):
             comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]

         self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
@@ -311,16 +311,16 @@ class CachedParamMgr(torch.nn.Module):
             self.num_write_back_history.append(0)

         # move sure the cuda rows will not be evicted!
-        with record_function("(zhg) cache update"):
+        with record_function("(pre-id) cache update"):
             self._prepare_rows_on_cuda(comm_cpu_row_idxs)

         self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)

-        with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
+        with record_function("(pre-id) embed cpu rows idx -> cache gpu row idxs"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)

         # update for LFU.
         if self._evict_strategy == EvictionStrategy.LFU:
             with record_function("(pre-id) lfu cnter updates"):
                 unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
                 self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
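The renamed "(pre-id)" labels surface as named ranges when the cache manager runs under torch.profiler. A minimal standalone sketch of inspecting such labels (the ids tensor and its size are fabricated for illustration):

import torch
from torch.profiler import profile, record_function

with profile() as prof:
    with record_function("(pre-id) get unique indices"):
        ids = torch.randint(0, 1000, (4096,))
        # unique rows plus occurrence counts, as CachedParamMgr computes them
        cpu_row_idxs, repeat_times = torch.unique(ids, return_counts=True)

# each record_function label shows up as its own row in the profile table
print(prof.key_averages().table(sort_by="cpu_time_total"))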
colossalai/nn/parallel/layers/cache_embedding/embedding_config.py (new file, mode 100644)

import torch


class TablewiseEmbeddingBagConfig:
    '''
    example:
    def prepare_tablewise_config(args, cache_ratio, ...):
        embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
        ...
        return embedding_bag_config_list
    '''

    def __init__(self,
                 num_embeddings: int,
                 cuda_row_num: int,
                 assigned_rank: int = 0,
                 buffer_size=50_000,
                 ids_freq_mapping=None,
                 initial_weight: torch.tensor = None,
                 name: str = ""):
        self.num_embeddings = num_embeddings
        self.cuda_row_num = cuda_row_num
        self.assigned_rank = assigned_rank
        self.buffer_size = buffer_size
        self.ids_freq_mapping = ids_freq_mapping
        self.initial_weight = initial_weight
        self.name = name
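The docstring only sketches prepare_tablewise_config; a minimal concrete version might look like the following (the signature, the round-robin rank policy, and the cache_ratio sizing are illustrative assumptions, not part of this commit):

from typing import List

def prepare_tablewise_config(table_sizes: List[int], cache_ratio: float,
                             world_size: int) -> List[TablewiseEmbeddingBagConfig]:
    # assumption: assign tables round-robin and cache a fixed fraction of rows on CUDA
    config_list: List[TablewiseEmbeddingBagConfig] = []
    for i, num_embeddings in enumerate(table_sizes):
        config_list.append(
            TablewiseEmbeddingBagConfig(num_embeddings=num_embeddings,
                                        cuda_row_num=max(1, int(num_embeddings * cache_ratio)),
                                        assigned_rank=i % world_size,
                                        name=f"table_{i}"))
    return config_list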
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.profiler import record_function
 from typing import List
-import abc
 import torch.nn.functional as F

 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .cache_mgr import EvictionStrategy
+from .embedding_config import TablewiseEmbeddingBagConfig
 from colossalai.tensor import ProcessGroup
 from colossalai.nn._ops._utils import dual_all_to_all_tablewise

-
-class TablewiseEmbeddingBagConfig:
-    '''
-    example:
-    def prepare_tablewise_config(args, cache_ratio, ...):
-        embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
-        ...
-        return embedding_bag_config_list
-    '''
-
-    def __init__(self,
-                 num_embeddings: int,
-                 cuda_row_num: int,
-                 assigned_rank: int = 0,
-                 buffer_size=50_000,
-                 ids_freq_mapping=None,
-                 initial_weight: torch.tensor = None,
-                 name: str = ""):
-        self.num_embeddings = num_embeddings
-        self.cuda_row_num = cuda_row_num
-        self.assigned_rank = assigned_rank
-        self.buffer_size = buffer_size
-        self.ids_freq_mapping = ids_freq_mapping
-        self.initial_weight = initial_weight
-        self.name = name
-
-from typing import List

 class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
@@ -44,6 +16,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
     all tables assigned to this class instance are managed by a single FreqAwareEmbeddingBag.
     Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
     """
+
     def __init__(self,
                  embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
                  embedding_dim: int,
@@ -98,7 +71,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         for table_i, table_num_embeddings in enumerate(self.global_table_num_embeddings_list):
             if self.rank_of_tables[table_i] == self.rank:
                 self.idx_offset_list.append(offset_cumsum)
-            else :
+            else:
                 offset_cumsum += table_num_embeddings

         # prepare list shape for all_to_all output
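To see what idx_offset_list holds, here is the loop above run by hand with the three tables from the unit test at the bottom of this page (sizes [6, 5, 7] assigned to ranks [0, 0, 1]); only unowned tables that precede an owned table contribute to the offset that is later subtracted from global ids:

rank_of_tables, table_sizes = [0, 0, 1], [6, 5, 7]
for my_rank in (0, 1):
    idx_offset_list, offset_cumsum = [], 0
    for table_i, table_num_embeddings in enumerate(table_sizes):
        if rank_of_tables[table_i] == my_rank:
            idx_offset_list.append(offset_cumsum)
        else:
            offset_cumsum += table_num_embeddings
    print(my_rank, idx_offset_list)
# rank 0 -> [0, 0]: global ids 0..10 are already local rows 0..10
# rank 1 -> [11]:   global ids 11..17 become local rows 0..6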
@@ -122,21 +95,23 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
                 else:
                     indices_end_position = offsets[batch_size * (handle_table + 1)]
                 # 1. local_indices_list:
                 local_indices_list.append(
                     indices.narrow(0, indices_start_position, indices_end_position -
                                    indices_start_position).sub(self.idx_offset_list[i]))
                 # 2. local_offsets_list:
                 if i + 1 == len(self.assigned_table_list):
                     # till-the-end special case
                     if not self.include_last_offset:
                         local_offsets = offsets.narrow(0, batch_size * handle_table,
                                                        batch_size).add(offset_pre_end - offsets[batch_size * (handle_table)])
                     else:
                         local_offsets = offsets.narrow(0, batch_size * handle_table,
                                                        batch_size + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                     local_offsets_list.append(local_offsets)
                 else:
                     local_offsets = offsets.narrow(0, batch_size * handle_table,
                                                    batch_size + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                     offset_pre_end = local_offsets[-1]
                     local_offsets_list.append(local_offsets[:-1])
                 # 3. local_per_sample_weights_list:
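A worked example of the slicing above, using the indices/offsets tensors from the unit test below (include_last_offset=True, 3 tables, batch_size 3). This simplified version subtracts each table's start offset directly; the in-tree code additionally re-bases with offset_pre_end so the per-table bags chain into a single lookup:

import torch

indices = torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11])
offsets = torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13])
batch_size, table_id_offsets = 3, [0, 6, 11]    # global id base of each table

for handle_table in range(3):
    start = offsets[batch_size * handle_table]
    end = offsets[batch_size * (handle_table + 1)] if handle_table < 2 else indices.shape[0]
    local_indices = indices[start:end] - table_id_offsets[handle_table]
    local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - start
    print(handle_table, local_indices.tolist(), local_offsets.tolist())
# 0 [1, 2, 3, 1, 5]  [0, 3, 3, 5]
# 1 [0, 1, 3, 0, 2]  [0, 2, 3, 5]
# 2 [2, 4, 0]        [0, 0, 2, 3]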
@@ -154,7 +129,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
                                            self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode,
                                            self.sparse, local_per_sample_weights, self.include_last_offset,
                                            self.padding_idx)
         local_output = torch.cat(local_output.split(batch_size), 1)

         remains = batch_size % self.world_size
         scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
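The torch.cat(local_output.split(batch_size), 1) step turns the row-stacked per-table outputs of shape (num_assigned_tables * batch_size, embedding_dim) into (batch_size, num_assigned_tables * embedding_dim), one row per sample. A standalone shape check (sizes are arbitrary):

import torch

batch_size, num_tables, embedding_dim = 3, 2, 5
# rows 0..2 hold table A's batch, rows 3..5 hold table B's batch
local_output = torch.arange(num_tables * batch_size * embedding_dim,
                            dtype=torch.float32).view(num_tables * batch_size, embedding_dim)
relaid = torch.cat(local_output.split(batch_size), 1)
print(relaid.shape)    # torch.Size([3, 10])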
@@ -168,130 +143,3 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):

     def element_size(self):
         return self.weight.element_size()
-
-
-class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
-    """
-    every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
-    """
-    ...

(The remaining removed lines are the full body of ParallelFreqAwareEmbeddingBagTablewiseSpiltCache — __init__, forward, element_size and print_comm_stats_ — which moves verbatim into the new module shown next.)
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise_split_cache.py (new file, mode 100644)

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.profiler import record_function

from .freq_aware_embedding import FreqAwareEmbeddingBag
from colossalai.tensor import ProcessGroup
from colossalai.nn._ops._utils import dual_all_to_all_tablewise
from .embedding_config import TablewiseEmbeddingBagConfig
from .cache_mgr import EvictionStrategy

from typing import List
import abc


class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
    """
    every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
    """

    def __init__(self,
                 embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
                 embedding_dim: int,
                 padding_idx=None,
                 max_norm=None,
                 norm_type=2.,
                 scale_grad_by_freq=False,
                 sparse=False,
                 mode='mean',
                 include_last_offset=False,
                 dtype=None,
                 device=None,
                 warmup_ratio=0.7,
                 pin_weight=False,
                 evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
        super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__()
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
        self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
        self.global_tables_num = len(embedding_bag_config_list)
        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
        self.assigned_table_list: List[int] = []
        for i, rank in enumerate(self.rank_of_tables):
            if rank == self.rank:
                self.assigned_table_list.append(i)
        self.include_last_offset = include_last_offset
        self.pg = ProcessGroup(tp_degree=self.world_size)

        # prepare FreqAwareEmbeddingBag list
        self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList()
        for config in embedding_bag_config_list:
            if config.assigned_rank != self.rank:
                continue
            self.freq_aware_embedding_bag_list.append(
                FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings,
                                      embedding_dim=embedding_dim,
                                      padding_idx=padding_idx,
                                      max_norm=max_norm,
                                      norm_type=norm_type,
                                      scale_grad_by_freq=scale_grad_by_freq,
                                      sparse=sparse,
                                      _weight=config.initial_weight,
                                      mode=mode,
                                      include_last_offset=include_last_offset,
                                      dtype=dtype,
                                      device=device,
                                      cuda_row_num=config.cuda_row_num,
                                      ids_freq_mapping=config.ids_freq_mapping,
                                      warmup_ratio=warmup_ratio,
                                      buffer_size=config.buffer_size,
                                      pin_weight=pin_weight,
                                      evict_strategy=evict_strategy))

        # prepare list shape for all_to_all output
        self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
        for rank in self.rank_of_tables:
            self.embedding_dim_per_rank[rank] += embedding_dim

    def forward(self,
                indices: torch.Tensor,
                offsets: torch.Tensor = None,
                per_sample_weights=None,
                shape_hook=None):
        # determine indices to handle
        batch_size = (offsets.shape[0]) // self.global_tables_num
        local_output_list = []
        for i, handle_table in enumerate(self.assigned_table_list):
            with record_function("(tablewise) prepare indices and offsets"):
                with record_function("part 1"):
                    indices_start_position = offsets[batch_size * handle_table]
                    if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
                        # till the end special case
                        indices_end_position = indices.shape[0]
                    else:
                        indices_end_position = offsets[batch_size * (handle_table + 1)]
                with record_function("part 2"):
                    # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table]
                    local_indices = indices.narrow(0, indices_start_position, indices_end_position -
                                                   indices_start_position).sub(self.global_tables_offsets[handle_table])
                    if self.include_last_offset:
                        # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
                        local_offsets = offsets.narrow(0, batch_size * handle_table,
                                                       batch_size + 1).sub(offsets[batch_size * (handle_table)])
                    else:
                        # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)]
                        local_offsets = offsets.narrow(0, batch_size * handle_table,
                                                       batch_size).sub(offsets[batch_size * (handle_table)])
                    local_per_sample_weights = None
                    if per_sample_weights != None:
                        local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
            with record_function("(tablewise) tablewise forward"):
                local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
                                                                               local_per_sample_weights))

        # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
        local_output = torch.cat(local_output_list, 1)
        # then concatenate those local_output on the second demension.
        # use all_to_all
        remains = batch_size % self.world_size
        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
        if shape_hook is not None:
            output_full = shape_hook(output_full)
        return output_full

    def element_size(self):
        if len(self.assigned_table_list) == 0:
            return 0
        return self.freq_aware_embedding_bag_list[0].cache_weight_mgr.weight.element_size()

    def print_comm_stats_(self):
        cuda_to_cpu_elem_num = 0
        cpu_to_cuda_elem_num = 0
        for freq_aware_embedding_bag in self.freq_aware_embedding_bag_list:
            cuda_to_cpu_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
            cpu_to_cuda_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
        print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem")
        print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem")
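scatter_strides splits a batch that may not divide evenly across ranks: the first batch_size % world_size ranks receive one extra row, and the strides always sum back to batch_size. dual_all_to_all_tablewise then uses these strides to exchange row slices between ranks. A standalone check of the arithmetic (values are arbitrary):

batch_size, world_size = 10, 4
remains = batch_size % world_size
scatter_strides = [batch_size // world_size + int(i < remains) for i in range(world_size)]
print(scatter_strides, sum(scatter_strides))    # [3, 3, 2, 2] 10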
tests/test_layers/test_cache_embedding.py

@@ -13,7 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
+    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
 from typing import List

 NUM_EMBED, EMBED_DIM = 10, 8
@@ -209,19 +209,28 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
     # initialize weight
     # 3 feature tables. idx: 0~5, 6~10, 11~17
     weight_tables = torch.rand(18, 5)
     weight_table1 = weight_tables[0:6]
     weight_table2 = weight_tables[6:11]
     weight_table3 = weight_tables[11:18]
     embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
-    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu()))
-    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu()))
-    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu()))
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=6,
+                                    cuda_row_num=4,
+                                    assigned_rank=0,
+                                    initial_weight=weight_table1.clone().detach().cpu()))
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=5,
+                                    cuda_row_num=4,
+                                    assigned_rank=0,
+                                    initial_weight=weight_table2.clone().detach().cpu()))
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=7,
+                                    cuda_row_num=4,
+                                    assigned_rank=1,
+                                    initial_weight=weight_table3.clone().detach().cpu()))
     if rank == 0:
         _weight = torch.cat([weight_table1, weight_table2], 0)
     else:
         _weight = weight_table3
     model = ParallelFreqAwareEmbeddingBagTablewise(
@@ -249,7 +258,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
     rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
     if rank == 0:
         fake_grad = rand_grad[0:2]
-    else :
+    else:
         fake_grad = rand_grad[2:]
     res.backward(fake_grad)
     optimizer.step()
@@ -261,7 +270,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
                                   include_last_offset=True,
                                   freeze=False).to(device)
     ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
     ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0)
     ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
                         torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
     ref_res.backward(ref_fake_grad)
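ref_fake_grad undoes the tablewise output layout: rand_grad has shape (batch, num_tables * dim) = (3, 15), and torch.cat(rand_grad.split(5, 1), 0) re-stacks it to (num_tables * batch, dim) = (9, 5), matching the reference EmbeddingBag's one-row-per-bag output. A standalone shape check:

import torch

rand_grad = torch.rand(3, 15)                       # (batch, num_tables * dim)
ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0)
print(ref_fake_grad.shape)                          # torch.Size([9, 5])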
@@ -273,6 +282,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
     ref_weight = ref_model.weight.detach()[:11]
     assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"

+
 def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     device = torch.device('cuda', torch.cuda.current_device())
@@ -289,7 +299,8 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
     coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
-    model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
+    model = ParallelFreqAwareEmbeddingBag.from_pretrained(
+        coloweight,
         include_last_offset=True,
         freeze=False,
         cuda_row_num=batch_size * 2,