Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cb98cf55
Unverified
Commit
cb98cf55
authored
Aug 10, 2022
by
Jiarui Fang
Committed by
GitHub
Aug 10, 2022
Browse files
[FAW] parallel FreqAwareEmbedding (#1424)
parent
0d212183
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
272 additions
and
2 deletions
+272
-2
colossalai/nn/_ops/_utils.py
colossalai/nn/_ops/_utils.py
+36
-0
colossalai/nn/_ops/cache_embedding/__init__.py
colossalai/nn/_ops/cache_embedding/__init__.py
+2
-1
colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
.../nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
+136
-0
tests/test_tensor/ops/test_cache_embedding.py
tests/test_tensor/ops/test_cache_embedding.py
+98
-1
No files found.
colossalai/nn/_ops/_utils.py
View file @
cb98cf55
...
...
@@ -195,3 +195,39 @@ def split_forward_gather_backward(input_, process_group, dim):
def
gather_forward_split_backward
(
input_
,
process_group
,
dim
):
return
_GatherForwardSplitBackward
.
apply
(
input_
,
process_group
,
dim
)
def
_all_to_all
(
x
:
torch
.
Tensor
,
pg
:
ProcessGroup
,
scatter_dim
:
int
,
gather_dim
:
int
)
->
torch
.
Tensor
:
world_size
=
pg
.
tp_world_size
()
if
world_size
==
1
:
return
x
# TODO: enabling mpi backend to support CPU all_to_all
assert
x
.
device
.
type
==
'cuda'
,
f
"Currently, the collective function dual_all_to_all only supports nccl backend"
shapes
=
list
(
x
.
size
())
shapes
[
scatter_dim
]
=
shapes
[
scatter_dim
]
//
world_size
scatter_list
=
[
each
.
contiguous
()
for
each
in
torch
.
tensor_split
(
x
,
world_size
,
scatter_dim
)]
gather_list
=
[
torch
.
empty
(
*
shapes
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
for
_
in
range
(
world_size
)]
torch
.
distributed
.
all_to_all
(
gather_list
,
scatter_list
,
group
=
pg
.
tp_process_group
())
return
torch
.
cat
(
gather_list
,
dim
=
gather_dim
).
contiguous
()
class _DualAllToAll(torch.autograd.Function):
    """Autograd wrapper around ``_all_to_all``.

    The backward pass is the transposed exchange: gradients are scattered along
    the forward gather dimension and gathered along the forward scatter
    dimension.
    """

    @staticmethod
    def forward(ctx, x, pg, scatter_dim, gather_dim):
        # Remember the group and both dims so backward can run the
        # mirrored communication.
        ctx.pg = pg
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        return _all_to_all(x, pg, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad):
        swapped = _all_to_all(grad, ctx.pg, ctx.gather_dim, ctx.scatter_dim)
        # None gradients for the pg / scatter_dim / gather_dim inputs.
        return swapped, None, None, None
def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
    """Differentiable all-to-all over the tensor-parallel group of ``pg``.

    Splits ``x`` along ``scatter_dim`` and concatenates the received chunks
    along ``gather_dim``; the backward pass performs the reverse exchange.
    """
    return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)
colossalai/nn/_ops/cache_embedding/__init__.py
View file @
cb98cf55
from .cache_mgr import CachedParamMgr
from .copyer import LimitBuffIndexCopyer
from .freq_aware_embedding import FreqAwareEmbeddingBag
from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag

# Public API of the cache_embedding package. The previous, narrower
# __all__ assignment (without ParallelFreqAwareEmbeddingBag) was dead code
# and has been removed.
__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag']
colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
0 → 100644
View file @
cb98cf55
import
torch
import
torch.nn.functional
as
F
from
typing
import
List
,
Optional
,
Iterator
,
Tuple
from
.base_embedding
import
BaseEmbeddingBag
from
.cache_mgr
import
CachedParamMgr
from
torch.nn.parameter
import
Parameter
from
.._utils
import
dual_all_to_all
from
colossalai.tensor
import
ColoParameter
,
ShardSpec
,
ComputeSpec
,
ComputePattern
,
ProcessGroup
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
    """Return the ``[start, end)`` column range owned by ``rank`` plus a
    divisibility flag.

    The split mirrors ``torch.tensor_split``: when ``embedding_dim`` is not a
    multiple of ``world_size``, the first ``embedding_dim % world_size`` ranks
    receive one extra column. The boolean is True iff the split is even.
    """
    if world_size == 1:
        return 0, embedding_dim, True

    assert embedding_dim >= world_size, \
        f"Embedding dimension {embedding_dim} must be larger than the world size " \
        f"{world_size} of the process group"

    base, leftover = divmod(embedding_dim, world_size)
    if leftover == 0:
        # Even split: every rank owns exactly `base` columns.
        return rank * base, (rank + 1) * base, True

    # Uneven split, aligned with torch.tensor_split's strategy.
    sizes = [base + 1 if idx < leftover else base for idx in range(world_size)]
    begin = sum(sizes[:rank])
    return begin, begin + sizes[rank], False
class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
    """A frequency-aware EmbeddingBag sharded along the embedding dimension.

    Each rank stores ``embedding_dim / world_size`` columns of the weight (split
    as in ``torch.tensor_split``) and serves them through a ``CachedParamMgr``
    CPU/GPU cache. ``forward`` computes the local shard and runs a dual
    all-to-all to assemble the full embedding output across ranks.

    NOTE(review): requires torch.distributed to be initialized before
    construction (``get_rank``/``get_world_size`` are called in ``__init__``).
    """

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 max_norm=None,
                 norm_type=2.,
                 scale_grad_by_freq=False,
                 sparse=False,
                 _weight=None,
                 mode='mean',
                 include_last_offset=False,
                 dtype=None,
                 debug=True):
        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse,
                         mode, include_last_offset)

        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()
        self.debug = debug

        # Column range owned by this rank (tensor_split-compatible).
        self.partition_start_index, self.partition_end_index, divisible = get_partition(
            embedding_dim, self.rank, self.world_size)
        self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index

        if _weight is None:
            # Create this rank's shard on CPU; CachedParamMgr moves hot rows to GPU later.
            self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings,
                                                                       self.embedding_dim_per_partition,
                                                                       device='cpu',
                                                                       dtype=dtype),
                                                           requires_grad=True,
                                                           spec=ShardSpec(dims=[-1],
                                                                          num_partitions=[self.world_size]))
            # FIX: attach the TP process group to the freshly created shard.
            # The original assigned to self._weight.process_group *before*
            # self._weight was created, mirroring nothing that existed yet.
            self._weight.process_group = ProcessGroup(tp_degree=self.world_size)
            self.init_parameters()
        else:
            assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
            _weight.process_group = ProcessGroup(tp_degree=self.world_size)
            _weight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[self.world_size]),
                                    ComputeSpec(ComputePattern.TP1D))
            self._weight = _weight

    @property
    def weight(self):
        # The authoritative (CPU-resident) copy of this rank's shard.
        return self.cache_weight_mgr.cpu_weight

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
        # Expose only the CUDA-cached weight: it is the tensor the optimizer
        # must update; the CPU copy is synchronized by the cache manager.
        yield 'weight', self.cache_weight_mgr.cuda_cached_weight

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        yield self.cache_weight_mgr.cuda_cached_weight

    @torch.no_grad()
    def init_parameters(self):
        """Uniform init in [-1/num_embeddings, 1/num_embeddings]; zero the padding row."""
        self._weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
        if self.padding_idx is not None:
            self._weight[self.padding_idx].fill_(0)

    def preprocess(self,
                   cuda_row_num: int,
                   ids_freq_mapping: Optional[List[int]] = None,
                   warmup_ratio: float = 0.7,
                   buffer_size: int = 50_000):
        """Build the CachedParamMgr over the local shard and warm up its cache.

        Args:
            cuda_row_num: number of embedding rows kept resident on GPU.
            ids_freq_mapping: optional access-frequency per id, used to reorder
                rows so hot ids land in the cache first.
            warmup_ratio: fraction of the cache pre-filled during reorder.
            buffer_size: row-copy buffer size handed to the manager.
        """
        self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size=buffer_size)
        self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)

    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
        """Embed ``indices`` locally, then all-to-all to return the full output.

        ``shape_hook``, if given, may reshape the local shard before the
        collective (e.g. to expose a batch dimension for the exchange).
        """
        with torch.no_grad():
            # Cache bookkeeping only — must not be recorded by autograd.
            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)

        output_shard = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                       per_sample_weights, self.include_last_offset, self.padding_idx)

        if shape_hook is not None:
            output_shard = shape_hook(output_shard)

        # Trade batch rows for embedding columns across ranks so every rank
        # ends up with full-width embeddings for its slice of the batch.
        output_full = dual_all_to_all(output_shard,
                                      self._weight.get_process_group(),
                                      scatter_dim=scatter_dim,
                                      gather_dim=gather_dim)
        return output_full

    @classmethod
    def from_pretrained(cls,
                        embedding: torch.Tensor,
                        freeze: bool = True,
                        padding_idx: Optional[int] = None,
                        max_norm: Optional[float] = None,
                        norm_type: float = 2.,
                        scale_grad_by_freq: bool = False,
                        sparse: bool = False,
                        mode: str = 'mean',
                        include_last_offset: bool = False,
                        debug: bool = True,
                        cuda_row_num: int = 100_000,
                        ids_freq_mapping: Optional[List[int]] = None,
                        warmup_ratio: float = 0.7) -> 'ParallelFreqAwareEmbeddingBag':
        """Alternate constructor from a pre-trained (already sharded) weight.

        ``embedding`` must be this rank's shard as a ColoParameter; its shape
        determines ``num_embeddings`` and the per-partition dimension.
        """
        rows, cols = embedding.shape
        embedding_bag = cls(rows, cols, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, embedding, mode,
                            include_last_offset, debug)
        embedding_bag.preprocess(cuda_row_num, ids_freq_mapping, warmup_ratio)
        # FIX: the original wrote `cuda_cached_weight.requires_grad_ = not freeze`,
        # which rebinds the bound method with a bool and never toggles the flag.
        embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_(not freeze)
        return embedding_bag
tests/test_tensor/ops/test_cache_embedding.py
View file @
cb98cf55
...
...
@@ -3,9 +3,13 @@ from functools import partial
import
torch
import
torch.multiprocessing
as
mp
import
numpy
as
np
import
random
import
colossalai
from
colossalai.utils
import
free_port
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.tensor
import
ColoParameter
from
colossalai.nn._ops.cache_embedding
import
CachedParamMgr
,
FreqAwareEmbeddingBag
,
ParallelFreqAwareEmbeddingBag
from
colossalai.nn._ops.cache_embedding
import
CachedParamMgr
,
FreqAwareEmbeddingBag
...
...
@@ -13,6 +17,15 @@ NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE
=
8
def set_seed(seed):
    """Seed the python, numpy and torch RNGs for reproducible test runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def
synthesize_1d_sparse_feature
(
batch_size
,
num_embed
,
...
...
@@ -128,7 +141,91 @@ def test_freq_aware_embed():
f
"model weight:
{
model_weight
[
10
:
18
,
:
8
]
}
, reference:
{
ref_weight
[
10
:
18
,
:
8
]
}
"
def gather_tensor(tensor, rank, world_size):
    """Gather ``tensor`` from all ranks onto rank 0.

    Returns the list of per-rank tensors on rank 0 and an empty list on every
    other rank (torch.distributed.gather requires an empty gather_list there).
    """
    receive_buffers = [torch.empty_like(tensor) for _ in range(world_size)] if rank == 0 else []
    torch.distributed.gather(tensor, receive_buffers, dst=0)
    return receive_buffers
def run_parallel_freq_aware_embed(rank, world_size):
    """Check the sharded ParallelFreqAwareEmbeddingBag against a dense
    nn.EmbeddingBag reference held on rank 0: forward outputs and post-update
    weights must agree."""
    device = torch.device('cuda', torch.cuda.current_device())

    num_embed, embed_dim, batch_size = 100, 16, 4

    set_seed(4321)
    weight = torch.rand(num_embed, embed_dim)
    coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)

    model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                          include_last_offset=True,
                                                          freeze=False,
                                                          cuda_row_num=batch_size * 2)

    assert model.cache_weight_mgr.cpu_weight.device.type == 'cpu'
    assert model.cache_weight_mgr.cuda_cached_weight.requires_grad

    # Each rank must hold its tensor_split column shard of the full weight.
    weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
    assert torch.allclose(weight_in_rank, model.cache_weight_mgr.cpu_weight.detach()), \
        f"{weight_in_rank - model.cache_weight_mgr.cpu_weight}"

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    if rank == 0:
        # Dense single-GPU reference model and optimizer.
        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(),
                                                          include_last_offset=True,
                                                          freeze=False).to(device)
        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)

    set_seed(4321)
    for step in range(5):
        indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device)
        res = model(indices, offsets)

        # Same seed on every rank -> identical grad; each rank backprops its row slice.
        grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device)
        grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank]
        res.backward(grad_in_rank)

        optimizer.step()
        optimizer.zero_grad()

        res_list = gather_tensor(res.detach(), rank, world_size)

        if rank == 0:
            ref_res = ref_model(indices, offsets)
            recover_res = torch.cat(res_list, dim=0)
            assert torch.allclose(ref_res, recover_res)

            ref_res.backward(grad)
            ref_optimizer.step()
            ref_optimizer.zero_grad()

    # Write cached GPU rows back to the CPU shard, then reassemble and compare.
    model.cache_weight_mgr.flush()
    weight_list = gather_tensor(model.cache_weight_mgr.cpu_weight.detach().cuda(), rank, world_size)
    if rank == 0:
        recover_weight = torch.cat(weight_list, dim=1)
        assert torch.allclose(recover_weight, ref_model.weight.detach()), \
            f"{recover_weight - ref_model.weight}"
def run_dist(rank, world_size, port):
    """Per-process entry point: bring up colossalai on nccl, then run the check."""
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    run_parallel_freq_aware_embed(rank, world_size)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_parallel_freq_aware_embed(world_size):
    """Spawn `world_size` processes and run the parallel embedding check."""
    launcher = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(launcher, nprocs=world_size)
if __name__ == '__main__':
    # test_freq_aware_embed()
    # test_chunkmgr_admit()
    test_parallel_freq_aware_embed(2)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment