OpenDAS / ColossalAI · Commit 0aad53c6
Unverified commit 0aad53c6, authored Aug 23, 2022 by Geng Zhang, committed by GitHub on Aug 23, 2022.

[FCE] update interface for frequency statistics in FreqCacheEmbedding (#1462)

Parent: ede32629
Showing 4 changed files with 30 additions and 25 deletions.
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py (+8 −4)
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py (+9 −8)
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py (+8 −8)
tests/test_layers/test_cache_embedding.py (+5 −5)
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -14,12 +14,17 @@ class CachedParamMgr(torch.nn.Module):
     During training, GPU needs to transmit rows between CPU and GPU.
     """

-    def __init__(self, weight: torch.Tensor, cuda_row_num: int = 0, buffer_size: int = 50_000) -> None:
+    def __init__(self, weight: torch.Tensor, cuda_row_num: int = 0, buffer_size: int = 50_000, pin_weight=False) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
         self.num_embeddings, self.embedding_dim = weight.shape
         self.cuda_row_num = cuda_row_num
         self._cuda_available_row_num = self.cuda_row_num
+        self.pin_weight = pin_weight
         self.elem_size_in_byte = weight.element_size()
@@ -43,8 +48,7 @@ class CachedParamMgr(torch.nn.Module):
                                  dtype=weight.dtype))
         # pin memory cpu for higher CPU-GPU copy bandwidth
-        self.weight = weight.contiguous().cpu().pin_memory()
+        self.weight = weight.pin_memory() if self.pin_weight else weight
         # map original id to new id with respect to frequency
         # id -> cpu_row_idx
         self.register_buffer(
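The new pin_weight flag makes page-locked host memory opt-in rather than unconditional: pinning raises CPU-GPU copy bandwidth and allows asynchronous transfers, but pinned pages cannot be swapped out, so pinning a very large embedding table can starve the host. A minimal sketch of the trade-off (illustrative sizes, not part of this commit):

    import torch

    rows, dim = 100_000, 128
    weight = torch.rand(rows, dim)      # ordinary pageable host memory

    pin_weight = True                   # mirrors the new constructor flag
    host_weight = weight.pin_memory() if pin_weight else weight

    if torch.cuda.is_available():
        # non_blocking=True only overlaps the copy with compute when the source is pinned
        rows_on_gpu = host_weight[:1024].to("cuda", non_blocking=True)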
@@ -109,7 +113,7 @@ class CachedParamMgr(torch.nn.Module):
             warmup_ratio (float): the amount of chunks preloaded in cuda cache
         """
         if ids_freq_mapping is not None:
-            tmp_idx = torch.argsort(torch.from_numpy(ids_freq_mapping).cuda(), descending=True)
+            tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
             sorted_idx = torch.argsort(tmp_idx)
             self.idx_map.data.copy_(sorted_idx)
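The reorder interface now takes ids_freq_mapping as a torch tensor instead of a numpy array that had to be converted and moved to CUDA inside the method. The double argsort turns raw frequencies into a rank for every id, so the hottest rows end up at the front of the reordered table. A self-contained illustration with made-up frequencies:

    import torch

    # freq[i] = how often embedding id i is accessed
    freq = torch.tensor([5, 50, 1, 20])

    tmp_idx = torch.argsort(freq, descending=True)    # ids by frequency: [1, 3, 0, 2]
    sorted_idx = torch.argsort(tmp_idx)               # rank of each id:  [2, 0, 3, 1]
    # id 1 (freq 50) maps to row 0, id 3 (freq 20) to row 1, and so on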
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
@@ -27,20 +27,19 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                  ids_freq_mapping=None,
                  warmup_ratio=0.7,
                  buffer_size=50_000,
+                 pin_weight=False,
                  ):
         super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
                                                     scale_grad_by_freq, sparse, mode, include_last_offset)
         if _weight is None:
             _weight = self._weight_alloc(dtype, device)
-        else:
-            _weight = _weight
         # configure weight & cache
-        self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size)
+        self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)

     def _weight_alloc(self, dtype, device):
-        weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device, pin_memory=True)
+        weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
         with torch.no_grad():
             weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
             if self.padding_idx is not None:
@@ -52,7 +51,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                     cuda_row_num: int,
                     ids_freq_mapping: Optional[List[int]] = None,
                     warmup_ratio=0.7,
-                    buffer_size=50_000):
+                    buffer_size=50_000,
+                    pin_weight=False):
         """
         Called after initialized.
         Reorder the weight rows according to the ids_freq_mapping.
@@ -63,17 +63,18 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
             ids_freq_mapping (List[int]): a list, idx is id number, value is freq
             warmup_ratio (float): the amount of rows preloaded in cuda cache
         """
-        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size)
+        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size, pin_weight)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)

-    def forward(self, indices, offsets=None, per_sample_weights=None):
+    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None):
         with torch.no_grad():
             reorder_ids = self.cache_weight_mgr.prepare_ids(indices)

         embeddings = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                      per_sample_weights, self.include_last_offset, self.padding_idx)
+        if shape_hook is not None:
+            embeddings = shape_hook(embeddings)
         return embeddings

     @property
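The added shape_hook lets the caller post-process the pooled output of F.embedding_bag before it is returned, for example to regroup flattened bags into a (batch, fields, dim) layout. A hypothetical hook (the names and shapes below are illustrative, not part of this commit):

    import torch

    batch_size, num_fields, dim = 32, 26, 128

    def regroup(embeddings: torch.Tensor) -> torch.Tensor:
        # restore a (batch, fields, dim) view of the flattened bags
        return embeddings.view(batch_size, num_fields, dim)

    flat = torch.rand(batch_size * num_fields, dim)   # stand-in for the bag output
    assert regroup(flat).shape == (batch_size, num_fields, dim)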
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
@@ -3,8 +3,6 @@ import torch.nn.functional as F
 from typing import List, Optional, Iterator, Tuple

 from .freq_aware_embedding import FreqAwareEmbeddingBag
-from .cache_mgr import CachedParamMgr
-from torch.nn.parameter import Parameter
 from colossalai.nn._ops._utils import dual_all_to_all
 from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
@@ -49,6 +47,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                  ids_freq_mapping=None,
                  warmup_ratio=0.7,
                  buffer_size=50_000,
+                 pin_weight=False,
                  ):
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
@@ -60,17 +59,18 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
         super(ParallelFreqAwareEmbeddingBag,
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
-                             warmup_ratio, buffer_size)
+                             warmup_ratio, buffer_size, pin_weight)

     def _weight_alloc(self, dtype, device):
+        weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
+        with torch.no_grad():
+            weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
+            if self.padding_idx is not None:
+                weight[self.padding_idx].fill_(0)
         colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
                                           dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
                                           compute_attr=ComputePattern.TP1D)
-        return ColoTensor.from_torch_tensor(torch.empty(self.num_embeddings,
-                                                        self.embedding_dim_per_partition,
-                                                        device=device,
-                                                        dtype=dtype),
-                                            spec=colo_tensor_spec)
+        return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)

     def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
         with torch.no_grad():
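Beyond threading pin_weight through, this hunk changes what _weight_alloc returns: the old version handed a bare torch.empty buffer to ColoTensor.from_torch_tensor, so the parallel embedding started from uninitialized memory, while the new version applies the same uniform initialization and padding-row zeroing as the single-process FreqAwareEmbeddingBag before wrapping the tensor. A plain-torch sketch of the difference (illustrative only, not ColossalAI API):

    import torch

    n, dim = 4, 2
    uninit = torch.empty(n, dim)          # arbitrary memory contents, unsafe to train from
    weight = torch.empty(n, dim)
    with torch.no_grad():
        weight.uniform_(-1 / n, 1 / n)    # bounded initialization
        weight[0].fill_(0)                # e.g. zero the padding row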
tests/test_layers/test_cache_embedding.py
@@ -44,7 +44,7 @@ def synthesize_1d_sparse_feature(
 def test_cachemgr():
     model = torch.nn.EmbeddingBag(10000, 128)
     # 10 chunks, 5 in cuda
-    mgr = CachedParamMgr(model.weight, 5)
+    mgr = CachedParamMgr(model.weight.detach(), 5)
     assert mgr.cuda_row_num == 5
     mgr._admit(1)
@@ -74,8 +74,8 @@ def test_reorder_with_freq():
     chunk_size = 1
     num_chunk = 5
-    idx_map = np.random.randint(10000, size=(num_embed,))
-    sorted_idx = np.flipud(np.argsort(idx_map)).tolist()
+    idx_map = torch.randint(10000, size=(num_embed,))
+    sorted_idx = torch.argsort(idx_map, descending=True).tolist()
     chunkid, offset_in_chunk = [], []
     for i in range(num_embed):
         idx = sorted_idx.index(i)
@@ -231,6 +231,6 @@ def test_parallel_freq_aware_embed(world_size):
 if __name__ == '__main__':
-    # test_cachemgr()
+    test_cachemgr()
     # test_freq_aware_embed()
-    test_parallel_freq_aware_embed(2)
+    # test_parallel_freq_aware_embed(2)
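With the tensor-based interface, a frequency mapping can be produced directly from observed indices with torch.bincount and passed straight to reorder, with no numpy round-trip. A hedged usage sketch (synthetic data; the mgr call is commented out because it needs a constructed CachedParamMgr):

    import torch

    num_embeddings = 10_000
    # synthetic stream of accessed embedding ids
    indices = torch.randint(num_embeddings, (100_000,))

    # per-id access frequency, usable directly as ids_freq_mapping
    ids_freq_mapping = torch.bincount(indices, minlength=num_embeddings)
    # mgr.reorder(ids_freq_mapping, warmup_ratio=0.7)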