Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f3403ff9
"vscode:/vscode.git/clone" did not exist on "43e7d54643f422badb42e2deb33bfd06ccd4ecf9"
Unverified
Commit
f3403ff9
authored
Sep 13, 2022
by
CsRic
Committed by
GitHub
Sep 13, 2022
Browse files
[embeddings] add already_split_along_rank flag for tablewise mode (#1584)
parent
77399dc9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
17 deletions
+41
-17
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
...ache_embedding/parallel_freq_aware_embedding_tablewise.py
+39
-16
tests/test_layers/test_cache_embedding.py
tests/test_layers/test_cache_embedding.py
+2
-1
No files found.
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
View file @
f3403ff9
...
@@ -9,6 +9,7 @@ from colossalai.tensor import ProcessGroup
...
@@ -9,6 +9,7 @@ from colossalai.tensor import ProcessGroup
from
colossalai.nn._ops._utils
import
dual_all_to_all_tablewise
from
colossalai.nn._ops._utils
import
dual_all_to_all_tablewise
from
typing
import
List
from
typing
import
List
import
time
class
ParallelFreqAwareEmbeddingBagTablewise
(
FreqAwareEmbeddingBag
):
class
ParallelFreqAwareEmbeddingBagTablewise
(
FreqAwareEmbeddingBag
):
...
@@ -79,8 +80,43 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
...
@@ -79,8 +80,43 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
for
rank
in
self
.
rank_of_tables
:
for
rank
in
self
.
rank_of_tables
:
self
.
embedding_dim_per_rank
[
rank
]
+=
embedding_dim
self
.
embedding_dim_per_rank
[
rank
]
+=
embedding_dim
def
forward
(
self
,
indices
:
torch
.
Tensor
,
offsets
:
torch
.
Tensor
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
):
def
forward
(
self
,
indices
:
torch
.
Tensor
,
offsets
:
torch
.
Tensor
=
None
,
per_sample_weights
=
None
,
shape_hook
=
None
,
already_split_along_rank
=
True
):
if
not
already_split_along_rank
:
# not recommanded. it takes time.
batch_size
=
(
offsets
.
shape
[
0
])
//
self
.
global_tables_num
batch_size
=
(
offsets
.
shape
[
0
])
//
self
.
global_tables_num
local_indices
,
local_offsets
,
local_per_sample_weights
=
self
.
split_along_rank
(
batch_size
,
indices
,
offsets
,
per_sample_weights
)
else
:
# recommanded.
batch_size
=
(
offsets
.
shape
[
0
])
//
len
(
self
.
assigned_table_list
)
local_indices
,
local_offsets
,
local_per_sample_weights
=
indices
,
offsets
,
per_sample_weights
with
torch
.
no_grad
():
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
local_indices
)
local_output
=
F
.
embedding_bag
(
reorder_ids
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
local_offsets
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
local_per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
local_output
=
torch
.
cat
(
local_output
.
split
(
batch_size
),
1
)
remains
=
batch_size
%
self
.
world_size
scatter_strides
=
[
batch_size
//
self
.
world_size
+
int
(
i
<
remains
)
for
i
in
range
(
self
.
world_size
)]
output_full
=
dual_all_to_all_tablewise
(
local_output
,
self
.
pg
,
scatter_strides
,
self
.
embedding_dim_per_rank
)
if
shape_hook
is
not
None
:
output_full
=
shape_hook
(
output_full
)
return
output_full
def
split_along_rank
(
self
,
batch_size
,
indices
:
torch
.
Tensor
,
offsets
:
torch
.
Tensor
=
None
,
per_sample_weights
=
None
):
'''
if input indices and offsets haven't been splitted along assigned rank, this function will do it.
it takes time. please consider splitting data during batch loading.
'''
local_indices_list
:
List
(
torch
.
Tensor
)
=
[]
local_indices_list
:
List
(
torch
.
Tensor
)
=
[]
local_offsets_list
:
List
(
torch
.
Tensor
)
=
[]
local_offsets_list
:
List
(
torch
.
Tensor
)
=
[]
if
per_sample_weights
!=
None
:
if
per_sample_weights
!=
None
:
...
@@ -145,20 +181,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
...
@@ -145,20 +181,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
local_per_sample_weights
=
None
local_per_sample_weights
=
None
if
per_sample_weights
!=
None
:
if
per_sample_weights
!=
None
:
local_per_sample_weights
=
torch
.
cat
(
local_per_sample_weights_list
,
0
)
local_per_sample_weights
=
torch
.
cat
(
local_per_sample_weights_list
,
0
)
with
torch
.
no_grad
():
return
local_indices
,
local_offsets
,
local_per_sample_weights
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
local_indices
)
local_output
=
F
.
embedding_bag
(
reorder_ids
.
cuda
(),
self
.
cache_weight_mgr
.
cuda_cached_weight
,
local_offsets
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
local_per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
local_output
=
torch
.
cat
(
local_output
.
split
(
batch_size
),
1
)
remains
=
batch_size
%
self
.
world_size
scatter_strides
=
[
batch_size
//
self
.
world_size
+
int
(
i
<
remains
)
for
i
in
range
(
self
.
world_size
)]
output_full
=
dual_all_to_all_tablewise
(
local_output
,
self
.
pg
,
scatter_strides
,
self
.
embedding_dim_per_rank
)
if
shape_hook
is
not
None
:
output_full
=
shape_hook
(
output_full
)
return
output_full
def
print_comm_stats_
(
self
):
def
print_comm_stats_
(
self
):
self
.
cache_weight_mgr
.
print_comm_stats
()
self
.
cache_weight_mgr
.
print_comm_stats
()
...
...
tests/test_layers/test_cache_embedding.py
View file @
f3403ff9
...
@@ -253,7 +253,8 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
...
@@ -253,7 +253,8 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
in KJT format
in KJT format
'''
'''
res
=
model
(
torch
.
tensor
([
1
,
2
,
3
,
1
,
5
,
6
,
7
,
9
,
6
,
8
,
13
,
15
,
11
],
device
=
device
),
res
=
model
(
torch
.
tensor
([
1
,
2
,
3
,
1
,
5
,
6
,
7
,
9
,
6
,
8
,
13
,
15
,
11
],
device
=
device
),
torch
.
tensor
([
0
,
3
,
3
,
5
,
7
,
8
,
10
,
10
,
12
,
13
],
device
=
device
))
torch
.
tensor
([
0
,
3
,
3
,
5
,
7
,
8
,
10
,
10
,
12
,
13
],
device
=
device
),
already_split_along_rank
=
False
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
1e-2
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
1e-2
)
rand_grad
=
torch
.
rand
(
3
,
5
*
3
,
dtype
=
res
.
dtype
,
device
=
res
.
device
)
rand_grad
=
torch
.
rand
(
3
,
5
*
3
,
dtype
=
res
.
dtype
,
device
=
res
.
device
)
if
rank
==
0
:
if
rank
==
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment