OpenDAS / ColossalAI · commit cb98cf55 (unverified)

[FAW] parallel FreqAwareEmbedding (#1424)

Authored Aug 10, 2022 by Jiarui Fang, committed via GitHub on Aug 10, 2022. Parent: 0d212183.

Showing 4 changed files with 272 additions and 2 deletions:
colossalai/nn/_ops/_utils.py (+36, −0)
colossalai/nn/_ops/cache_embedding/__init__.py (+2, −1)
colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py (+136, −0)
tests/test_tensor/ops/test_cache_embedding.py (+98, −1)
colossalai/nn/_ops/_utils.py
@@ -195,3 +195,39 @@ def split_forward_gather_backward(input_, process_group, dim):

```python
def gather_forward_split_backward(input_, process_group, dim):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim)


def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    world_size = pg.tp_world_size()
    if world_size == 1:
        return x

    # TODO: enabling mpi backend to support CPU all_to_all
    assert x.device.type == 'cuda', "Currently, the collective function dual_all_to_all only supports nccl backend"

    shapes = list(x.size())
    shapes[scatter_dim] = shapes[scatter_dim] // world_size

    scatter_list = [each.contiguous() for each in torch.tensor_split(x, world_size, scatter_dim)]
    gather_list = [torch.empty(*shapes, dtype=x.dtype, device=x.device) for _ in range(world_size)]
    torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())

    return torch.cat(gather_list, dim=gather_dim).contiguous()


class _DualAllToAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, pg, scatter_dim, gather_dim):
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.pg = pg
        return _all_to_all(x, pg, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad):
        return _all_to_all(grad, ctx.pg, ctx.gather_dim, ctx.scatter_dim), None, None, None


def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
    return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)
```
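To make the collective concrete, here is a minimal single-process sketch (not from the commit) of the data movement `_all_to_all` performs across W ranks: rank `src` splits its tensor along `scatter_dim` and sends piece `dst` to rank `dst`; each rank then concatenates what it receives along `gather_dim`. `simulate_all_to_all` is a hypothetical helper written only for this illustration.

```python
import torch

def simulate_all_to_all(per_rank_inputs, scatter_dim=0, gather_dim=-1):
    world_size = len(per_rank_inputs)
    # chunks[src][dst] is the piece rank `src` would send to rank `dst`
    chunks = [torch.tensor_split(x, world_size, scatter_dim) for x in per_rank_inputs]
    # each destination rank concatenates the pieces it receives along gather_dim
    return [torch.cat([chunks[src][dst] for src in range(world_size)], dim=gather_dim)
            for dst in range(world_size)]

B, D, W = 4, 8, 2
full = torch.arange(B * D, dtype=torch.float32).reshape(B, D)
shards = list(torch.tensor_split(full, W, dim=-1))    # (B, D/W) per rank
outputs = simulate_all_to_all(shards)                 # (B/W, D) per rank
# rank r ends up with its batch shard of the *full* embedding dim
assert all(torch.equal(outputs[r], torch.tensor_split(full, W, dim=0)[r]) for r in range(W))
```

This is exactly the exchange the embedding forward pass below relies on: each rank computes the whole batch for its slice of the embedding dimension, then trades batch rows for embedding columns.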
colossalai/nn/_ops/cache_embedding/__init__.py
```diff
 from .cache_mgr import CachedParamMgr
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag

-__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag']
+__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag']
```
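A quick smoke check that the new symbol is exported (assumes this commit of colossalai is importable):

```python
from colossalai.nn._ops import cache_embedding

# ParallelFreqAwareEmbeddingBag is now part of the package's public API
assert 'ParallelFreqAwareEmbeddingBag' in cache_embedding.__all__
```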
colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
(new file, mode 100644)

```python
import torch
import torch.nn.functional as F
from typing import List, Optional, Iterator, Tuple

from .base_embedding import BaseEmbeddingBag
from .cache_mgr import CachedParamMgr
from torch.nn.parameter import Parameter
from .._utils import dual_all_to_all

from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup


def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
    if world_size == 1:
        return 0, embedding_dim, True

    assert embedding_dim >= world_size, \
        f"Embedding dimension {embedding_dim} must be at least as large as the world size " \
        f"{world_size} of the process group"

    chunk_size = embedding_dim // world_size
    threshold = embedding_dim % world_size
    # if embedding dim is divisible by world size
    if threshold == 0:
        return rank * chunk_size, (rank + 1) * chunk_size, True

    # align with the split strategy of torch.tensor_split
    size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)]
    offset = sum(size_list[:rank])
    return offset, offset + size_list[rank], False
```
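Because the test below splits reference weights with `torch.tensor_split`, `get_partition` must reproduce its uneven-split boundaries exactly. A small self-contained check (the function body is copied inline, minus the assert, so the snippet runs standalone; the sizes are arbitrary):

```python
import torch

# get_partition copied inline from the new file above; 10 dims over 4 ranks
# exercises the non-divisible branch (chunk sizes 3, 3, 2, 2)
def get_partition(embedding_dim, rank, world_size):
    chunk_size = embedding_dim // world_size
    threshold = embedding_dim % world_size
    if threshold == 0:
        return rank * chunk_size, (rank + 1) * chunk_size, True
    size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)]
    offset = sum(size_list[:rank])
    return offset, offset + size_list[rank], False

embedding_dim, world_size = 10, 4
ref = torch.tensor_split(torch.zeros(1, embedding_dim), world_size, dim=-1)
for rank in range(world_size):
    start, end, _ = get_partition(embedding_dim, rank, world_size)
    # each rank's slice width must match torch.tensor_split's chunk width
    assert end - start == ref[rank].shape[-1]
    assert start == sum(r.shape[-1] for r in ref[:rank])
```

The new file continues with the `ParallelFreqAwareEmbeddingBag` class itself: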
```python
class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 max_norm=None,
                 norm_type=2.,
                 scale_grad_by_freq=False,
                 sparse=False,
                 _weight=None,
                 mode='mean',
                 include_last_offset=False,
                 dtype=None,
                 debug=True):
        super(ParallelFreqAwareEmbeddingBag,
              self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                             sparse, mode, include_last_offset)

        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()
        self.debug = debug

        self.partition_start_index, self.partition_end_index, divisible = get_partition(
            embedding_dim, self.rank, self.world_size)
        self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index

        if _weight is None:
            self._weight.process_group = ProcessGroup(tp_degree=self.world_size)
            self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings,
                                                                       self.embedding_dim_per_partition,
                                                                       device='cpu',
                                                                       dtype=dtype),
                                                           requires_grad=True,
                                                           spec=ShardSpec(dims=[-1], num_partitions=[self.world_size]))
            self.init_parameters()
        else:
            assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
            _weight.process_group = ProcessGroup(tp_degree=self.world_size)
            _weight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[self.world_size]),
                                    ComputeSpec(ComputePattern.TP1D))
            self._weight = _weight

    @property
    def weight(self):
        return self.cache_weight_mgr.cpu_weight

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
        yield 'weight', self.cache_weight_mgr.cuda_cached_weight

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        yield self.cache_weight_mgr.cuda_cached_weight

    @torch.no_grad()
    def init_parameters(self):
        self._weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
        if self.padding_idx is not None:
            self._weight[self.padding_idx].fill_(0)

    def preprocess(self,
                   cuda_row_num: int,
                   ids_freq_mapping: Optional[List[int]] = None,
                   warmup_ratio: float = 0.7,
                   buffer_size: int = 50_000):
        self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size=buffer_size)
        self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)

    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
        with torch.no_grad():
            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)

        output_shard = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                       per_sample_weights, self.include_last_offset, self.padding_idx)

        if shape_hook is not None:
            output_shard = shape_hook(output_shard)

        output_full = dual_all_to_all(output_shard,
                                      self._weight.get_process_group(),
                                      scatter_dim=scatter_dim,
                                      gather_dim=gather_dim)
        return output_full

    @classmethod
    def from_pretrained(cls,
                        embedding: torch.Tensor,
                        freeze: bool = True,
                        padding_idx: Optional[int] = None,
                        max_norm: Optional[float] = None,
                        norm_type: float = 2.,
                        scale_grad_by_freq: bool = False,
                        sparse: bool = False,
                        mode: str = 'mean',
                        include_last_offset: bool = False,
                        debug: bool = True,
                        cuda_row_num: int = 100_000,
                        ids_freq_mapping: Optional[List[int]] = None,
                        warmup_ratio: float = 0.7) -> 'ParallelFreqAwareEmbeddingBag':
        rows, cols = embedding.shape
        embedding_bag = cls(rows, cols, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, embedding, mode,
                            include_last_offset, debug)
        embedding_bag.preprocess(cuda_row_num, ids_freq_mapping, warmup_ratio)
        # plain attribute assignment; assigning to `requires_grad_` would only shadow the in-place method
        embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad = not freeze
        return embedding_bag
```
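A minimal usage sketch, mirroring the test below (assumptions: a NCCL process group is already initialized, e.g. via `colossalai.launch`, and CUDA is available; the sizes are made up):

```python
import torch
from colossalai.tensor import ColoParameter
from colossalai.nn._ops.cache_embedding import ParallelFreqAwareEmbeddingBag

# build the bag from a full (unsharded) table; each rank keeps only its
# embedding-dim shard on CPU and caches `cuda_row_num` hot rows on GPU
weight = torch.rand(100, 16)
coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                      include_last_offset=True,
                                                      freeze=False,
                                                      cuda_row_num=8)
```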
tests/test_tensor/ops/test_cache_embedding.py
@@ -3,9 +3,13 @@ from functools import partial

```python
import torch
import torch.multiprocessing as mp
import numpy as np
import random

import colossalai
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ColoParameter
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
```

@@ -13,6 +17,15 @@ NUM_EMBED, EMBED_DIM = 10, 8

```python
BATCH_SIZE = 8


def set_seed(seed):
    """
    To achieve reproducible results, it's necessary to fix random seeds
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def synthesize_1d_sparse_feature(
    batch_size,
    num_embed,
```

...

@@ -128,7 +141,91 @@ def test_freq_aware_embed():

```python
        f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"


def gather_tensor(tensor, rank, world_size):
    gather_list = []
    if rank == 0:
        gather_list = [torch.empty_like(tensor) for _ in range(world_size)]

    torch.distributed.gather(tensor, gather_list, dst=0)
    return gather_list


def run_parallel_freq_aware_embed(rank, world_size):
    device = torch.device('cuda', torch.cuda.current_device())

    num_embed = 100
    embed_dim = 16
    batch_size = 4

    set_seed(4321)
    weight = torch.rand(num_embed, embed_dim)
    coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
    model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                          include_last_offset=True,
                                                          freeze=False,
                                                          cuda_row_num=batch_size * 2)

    assert model.cache_weight_mgr.cpu_weight.device.type == 'cpu'
    assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
    weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
    assert torch.allclose(weight_in_rank,
                          model.cache_weight_mgr.cpu_weight.detach()), \
        f"{weight_in_rank - model.cache_weight_mgr.cpu_weight}"

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    if rank == 0:
        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(),
                                                          include_last_offset=True,
                                                          freeze=False).to(device)
        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)

    set_seed(4321)
    for i in range(5):
        indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device)
        res = model(indices, offsets)

        grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device)
        grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank]
        res.backward(grad_in_rank)

        optimizer.step()
        optimizer.zero_grad()

        res_list = gather_tensor(res.detach(), rank, world_size)

        if rank == 0:
            ref_res = ref_model(indices, offsets)
            recover_res = torch.cat(res_list, dim=0)

            assert torch.allclose(ref_res, recover_res)

            ref_res.backward(grad)
            ref_optimizer.step()
            ref_optimizer.zero_grad()

    model.cache_weight_mgr.flush()
    weight_list = gather_tensor(model.cache_weight_mgr.cpu_weight.detach().cuda(), rank, world_size)
    if rank == 0:
        recover_weight = torch.cat(weight_list, dim=1)
        assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_parallel_freq_aware_embed(rank, world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_parallel_freq_aware_embed(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    # test_freq_aware_embed()
    # test_chunkmgr_admit()
    test_parallel_freq_aware_embed(2)
```
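Since `synthesize_1d_sparse_feature` is truncated above, here is a standard-PyTorch refresher (not part of this commit) on the 1D indices/offsets format with `include_last_offset=True` that the test relies on; the numbers are made up:

```python
import torch

# with include_last_offset=True, offsets has batch_size + 1 entries and the
# final entry must equal len(indices); bag b pools indices[offsets[b]:offsets[b+1]]
bag = torch.nn.EmbeddingBag(10, 4, mode='mean', include_last_offset=True)
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4, 4, 8])    # 3 bags; bag 1 is empty (mean -> zeros)
out = bag(indices, offsets)
print(out.shape)                        # torch.Size([3, 4])
```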