Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9a9ef653
Unverified
Commit
9a9ef653
authored
Aug 30, 2022
by
Jiarui Fang
Committed by
GitHub
Aug 30, 2022
Browse files
[FAW] cpu caching operations (#1520)
parent
481aecb0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
88 additions
and
66 deletions
+88
-66
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+55
-33
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
...n/parallel/layers/cache_embedding/freq_aware_embedding.py
+3
-3
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
...l/layers/cache_embedding/parallel_freq_aware_embedding.py
+23
-24
tests/test_layers/test_cache_embedding.py
tests/test_layers/test_cache_embedding.py
+7
-6
No files found.
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
View file @
9a9ef653
...
@@ -30,6 +30,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -30,6 +30,7 @@ class CachedParamMgr(torch.nn.Module):
`EvictionStrategy.LFU`: use the least frequently used cache.
`EvictionStrategy.LFU`: use the least frequently used cache.
`EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
`EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
Defaults to EvictionStrategy.DATASET.
Defaults to EvictionStrategy.DATASET.
use_cpu_caching (bool, optional): use the CPU to execute cache indexing. It is slower than using the GPU. Defaults to False.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -39,6 +40,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -39,6 +40,7 @@ class CachedParamMgr(torch.nn.Module):
buffer_size
:
int
=
50_000
,
buffer_size
:
int
=
50_000
,
pin_weight
:
bool
=
False
,
pin_weight
:
bool
=
False
,
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
DATASET
,
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
DATASET
,
use_cpu_caching
=
False
,
)
->
None
:
)
->
None
:
super
(
CachedParamMgr
,
self
).
__init__
()
super
(
CachedParamMgr
,
self
).
__init__
()
self
.
buffer_size
=
buffer_size
self
.
buffer_size
=
buffer_size
...
@@ -48,6 +50,13 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -48,6 +50,13 @@ class CachedParamMgr(torch.nn.Module):
self
.
pin_weight
=
pin_weight
self
.
pin_weight
=
pin_weight
self
.
elem_size_in_byte
=
weight
.
element_size
()
self
.
elem_size_in_byte
=
weight
.
element_size
()
self
.
_cpu_caching
=
use_cpu_caching
if
self
.
_cpu_caching
:
self
.
_cache_dev
=
torch
.
device
(
'cpu'
)
else
:
self
.
_cache_dev
=
torch
.
cuda
.
current_device
()
# weight configure
# weight configure
self
.
_init_weight
(
weight
)
self
.
_init_weight
(
weight
)
...
@@ -62,10 +71,15 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -62,10 +71,15 @@ class CachedParamMgr(torch.nn.Module):
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
# cache_row_idx -> frequency, freq of the cache rows.
# cache_row_idx -> frequency, freq of the cache rows.
# classic lfu cache. evict the minimal freq value row in cuda cache.
# classic lfu cache. evict the minimal freq value row in cuda cache.
self
.
register_buffer
(
"freq_cnter"
,
if
self
.
_cpu_caching
:
torch
.
empty
(
self
.
cuda_row_num
,
device
=
torch
.
cuda
.
current_device
(),
self
.
freq_cnter
=
torch
.
empty
(
self
.
cuda_row_num
,
device
=
self
.
_cache_dev
,
dtype
=
torch
.
long
).
fill_
(
sys
.
maxsize
),
dtype
=
torch
.
long
).
fill_
(
sys
.
maxsize
)
persistent
=
False
)
else
:
self
.
register_buffer
(
"freq_cnter"
,
torch
.
empty
(
self
.
cuda_row_num
,
device
=
self
.
_cache_dev
,
dtype
=
torch
.
long
).
fill_
(
sys
.
maxsize
),
persistent
=
False
)
def
_find_evict_gpu_idxs
(
self
,
evict_num
:
int
)
->
torch
.
Tensor
:
def
_find_evict_gpu_idxs
(
self
,
evict_num
:
int
)
->
torch
.
Tensor
:
"""_find_evict_gpu_idxs
"""_find_evict_gpu_idxs
...
@@ -105,26 +119,32 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -105,26 +119,32 @@ class CachedParamMgr(torch.nn.Module):
self
.
weight
=
weight
.
pin_memory
()
if
self
.
pin_weight
else
weight
self
.
weight
=
weight
.
pin_memory
()
if
self
.
pin_weight
else
weight
# map original id to new id with respect to frequency
# map original id to new id with respect to frequency
# id -> cpu_row_idx
# id -> cpu_row_idx
self
.
register_buffer
(
"idx_map"
,
torch
.
arange
(
self
.
num_embeddings
,
dtype
=
torch
.
long
,
device
=
torch
.
cuda
.
current_device
()),
persistent
=
False
,
)
# cached_idx_map: gpu_row_idx -> cpu_row_idx
self
.
register_buffer
(
"cached_idx_map"
,
torch
.
empty
(
self
.
cuda_row_num
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
torch
.
long
).
fill_
(
-
1
),
persistent
=
False
)
# cpu_row_id -> gpu_row_idx.
if
self
.
_cpu_caching
:
# gpu_row_idx as -1 means cpu_row_id not in CUDA.
self
.
idx_map
=
torch
.
arange
(
self
.
num_embeddings
,
dtype
=
torch
.
long
,
device
=
self
.
_cache_dev
)
self
.
register_buffer
(
"inverted_cached_idx"
,
self
.
cached_idx_map
=
torch
.
empty
(
self
.
cuda_row_num
,
device
=
self
.
_cache_dev
,
dtype
=
torch
.
long
).
fill_
(
-
1
)
torch
.
zeros
(
self
.
num_embeddings
,
device
=
torch
.
cuda
.
current_device
(),
self
.
inverted_cached_idx
=
torch
.
zeros
(
self
.
num_embeddings
,
device
=
self
.
_cache_dev
,
dtype
=
torch
.
long
).
fill_
(
-
1
),
dtype
=
torch
.
long
).
fill_
(
-
1
)
persistent
=
False
)
else
:
self
.
register_buffer
(
self
.
evict_backlist
=
torch
.
tensor
([],
device
=
torch
.
cuda
.
current_device
())
"idx_map"
,
torch
.
arange
(
self
.
num_embeddings
,
dtype
=
torch
.
long
,
device
=
self
.
_cache_dev
),
persistent
=
False
,
)
# cached_idx_map: gpu_row_idx -> cpu_row_idx
self
.
register_buffer
(
"cached_idx_map"
,
torch
.
empty
(
self
.
cuda_row_num
,
device
=
self
.
_cache_dev
,
dtype
=
torch
.
long
).
fill_
(
-
1
),
persistent
=
False
)
# cpu_row_id -> gpu_row_idx.
# gpu_row_idx as -1 means cpu_row_id not in CUDA.
self
.
register_buffer
(
"inverted_cached_idx"
,
torch
.
zeros
(
self
.
num_embeddings
,
device
=
self
.
_cache_dev
,
dtype
=
torch
.
long
).
fill_
(
-
1
),
persistent
=
False
)
self
.
evict_backlist
=
torch
.
tensor
([],
device
=
self
.
_cache_dev
)
# index copy buffer size should be less than 10% of cuda weight.
# index copy buffer size should be less than 10% of cuda weight.
if
self
.
buffer_size
>
0
:
if
self
.
buffer_size
>
0
:
...
@@ -191,24 +211,24 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -191,24 +211,24 @@ class CachedParamMgr(torch.nn.Module):
# extract rows from cpu weight
# extract rows from cpu weight
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
and
ids_freq_mapping
is
not
None
:
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
and
ids_freq_mapping
is
not
None
:
freq_value
,
preload_cpu_ids
=
torch
.
topk
(
ids_freq_mapping
,
preload_row_num
,
dim
=
0
,
largest
=
True
)
freq_value
,
preload_cpu_ids
=
torch
.
topk
(
ids_freq_mapping
,
preload_row_num
,
dim
=
0
,
largest
=
True
)
preload_cuda_row_idxs
=
torch
.
arange
(
preload_row_num
).
cuda
(
)
preload_cuda_row_idxs
=
torch
.
arange
(
preload_row_num
).
to
(
self
.
_cache_dev
)
else
:
else
:
preload_cpu_ids
=
torch
.
arange
(
preload_row_num
)
preload_cpu_ids
=
torch
.
arange
(
preload_row_num
)
preload_cuda_row_idxs
=
preload_cpu_ids
.
cuda
(
)
preload_cuda_row_idxs
=
preload_cpu_ids
.
to
(
self
.
_cache_dev
)
if
self
.
buffer_size
>
0
:
if
self
.
buffer_size
>
0
:
self
.
limit_buff_index_copyer
.
index_copy
(
0
,
self
.
limit_buff_index_copyer
.
index_copy
(
0
,
src_index
=
preload_cpu_ids
,
src_index
=
preload_cpu_ids
,
tgt_index
=
preload_cuda_row_idxs
,
tgt_index
=
preload_cuda_row_idxs
.
cuda
()
,
src
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
),
src
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
),
tgt
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
))
tgt
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
))
else
:
else
:
preload_rows
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_select
(
0
,
preload_cpu_ids
).
cuda
()
preload_rows
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_select
(
0
,
preload_cpu_ids
).
cuda
()
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_copy_
(
0
,
preload_cuda_row_idxs
,
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_copy_
(
0
,
preload_cuda_row_idxs
.
cuda
()
,
preload_rows
)
preload_rows
)
# update auxiliary info
# update auxiliary info
self
.
cached_idx_map
[
preload_cuda_row_idxs
]
=
preload_cpu_ids
.
cuda
(
)
self
.
cached_idx_map
[
preload_cuda_row_idxs
]
=
preload_cpu_ids
.
to
(
self
.
_cache_dev
)
self
.
inverted_cached_idx
[
preload_cpu_ids
]
=
preload_cuda_row_idxs
self
.
inverted_cached_idx
[
preload_cpu_ids
]
=
preload_cuda_row_idxs
self
.
_cuda_available_row_num
-=
preload_row_num
self
.
_cuda_available_row_num
-=
preload_row_num
...
@@ -217,7 +237,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -217,7 +237,7 @@ class CachedParamMgr(torch.nn.Module):
if
ids_freq_mapping
is
None
:
if
ids_freq_mapping
is
None
:
self
.
freq_cnter
.
index_fill_
(
0
,
preload_cuda_row_idxs
,
0
)
self
.
freq_cnter
.
index_fill_
(
0
,
preload_cuda_row_idxs
,
0
)
else
:
else
:
self
.
freq_cnter
[
preload_cuda_row_idxs
]
=
freq_value
.
cuda
(
)
self
.
freq_cnter
[
preload_cuda_row_idxs
]
=
freq_value
.
to
(
self
.
_cache_dev
)
print
(
f
'Cache warmup finished cost
{
timer
.
elapsed
}
sec.'
)
print
(
f
'Cache warmup finished cost
{
timer
.
elapsed
}
sec.'
)
...
@@ -227,7 +247,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -227,7 +247,7 @@ class CachedParamMgr(torch.nn.Module):
"""
"""
slots
=
torch
.
nonzero
(
self
.
cached_idx_map
>
-
1
).
squeeze
(
1
)
slots
=
torch
.
nonzero
(
self
.
cached_idx_map
>
-
1
).
squeeze
(
1
)
row_ids
=
self
.
cached_idx_map
[
slots
]
row_ids
=
self
.
cached_idx_map
[
slots
]
rows
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_select
(
0
,
slots
).
cpu
()
rows
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_select
(
0
,
slots
.
cuda
()
).
cpu
()
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_copy_
(
0
,
row_ids
.
cpu
(),
rows
)
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_copy_
(
0
,
row_ids
.
cpu
(),
rows
)
self
.
cached_idx_map
.
index_fill_
(
0
,
slots
,
-
1
)
self
.
cached_idx_map
.
index_fill_
(
0
,
slots
,
-
1
)
self
.
inverted_cached_idx
.
index_fill_
(
0
,
row_ids
,
-
1
)
self
.
inverted_cached_idx
.
index_fill_
(
0
,
row_ids
,
-
1
)
...
@@ -276,6 +296,7 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -276,6 +296,7 @@ class CachedParamMgr(torch.nn.Module):
torch.Tensor: indices on the cuda_cached_weight.
torch.Tensor: indices on the cuda_cached_weight.
"""
"""
with
record_function
(
"(zhg) get unique indices"
):
with
record_function
(
"(zhg) get unique indices"
):
ids
=
ids
.
to
(
self
.
_cache_dev
)
cpu_row_idxs
,
repeat_times
=
torch
.
unique
(
self
.
idx_map
.
index_select
(
0
,
ids
),
return_counts
=
True
)
cpu_row_idxs
,
repeat_times
=
torch
.
unique
(
self
.
idx_map
.
index_select
(
0
,
ids
),
return_counts
=
True
)
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
...
@@ -353,7 +374,8 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -353,7 +374,8 @@ class CachedParamMgr(torch.nn.Module):
tgt
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
))
tgt
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
))
else
:
else
:
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
rows
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_select
(
0
,
evict_gpu_row_idxs
).
cpu
()
rows
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_select
(
0
,
evict_gpu_row_idxs
.
cuda
()).
cpu
()
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_copy_
(
0
,
evict_info
.
cpu
(),
rows
)
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_copy_
(
0
,
evict_info
.
cpu
(),
rows
)
self
.
cached_idx_map
.
index_fill_
(
0
,
evict_gpu_row_idxs
,
-
1
)
self
.
cached_idx_map
.
index_fill_
(
0
,
evict_gpu_row_idxs
,
-
1
)
...
@@ -372,12 +394,12 @@ class CachedParamMgr(torch.nn.Module):
...
@@ -372,12 +394,12 @@ class CachedParamMgr(torch.nn.Module):
if
self
.
buffer_size
>
0
:
if
self
.
buffer_size
>
0
:
self
.
limit_buff_index_copyer
.
index_copy
(
0
,
self
.
limit_buff_index_copyer
.
index_copy
(
0
,
src_index
=
cpu_row_idxs
.
cpu
(),
src_index
=
cpu_row_idxs
.
cpu
(),
tgt_index
=
slots
,
tgt_index
=
slots
.
cuda
()
,
src
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
),
src
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
),
tgt
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
))
tgt
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
))
else
:
else
:
rows
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_select
(
0
,
cpu_row_idxs
.
cpu
()).
cuda
()
rows
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_select
(
0
,
cpu_row_idxs
.
cpu
()).
cuda
()
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_copy_
(
0
,
slots
,
rows
)
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_copy_
(
0
,
slots
.
cuda
()
,
rows
)
slot_offsets
=
slots
slot_offsets
=
slots
self
.
cached_idx_map
[
slots
]
=
cpu_row_idxs
self
.
cached_idx_map
[
slots
]
=
cpu_row_idxs
self
.
inverted_cached_idx
.
index_copy_
(
0
,
cpu_row_idxs
,
slot_offsets
)
self
.
inverted_cached_idx
.
index_copy_
(
0
,
cpu_row_idxs
,
slot_offsets
)
...
...
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
View file @
9a9ef653
...
@@ -74,8 +74,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
...
@@ -74,8 +74,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
indices
)
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
indices
)
embeddings
=
F
.
embedding_bag
(
reorder_ids
,
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
self
.
max_norm
,
embeddings
=
F
.
embedding_bag
(
reorder_ids
.
cuda
()
,
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
if
shape_hook
is
not
None
:
if
shape_hook
is
not
None
:
embeddings
=
shape_hook
(
embeddings
)
embeddings
=
shape_hook
(
embeddings
)
...
@@ -119,4 +119,4 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
...
@@ -119,4 +119,4 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
if
self
.
cache_weight_mgr
.
_cuda_to_cpu_numel
>
0
:
if
self
.
cache_weight_mgr
.
_cuda_to_cpu_numel
>
0
:
return
self
.
cache_weight_mgr
.
_cuda_to_cpu_numel
*
self
.
cache_weight_mgr
.
elem_size_in_byte
/
1e6
/
\
return
self
.
cache_weight_mgr
.
_cuda_to_cpu_numel
*
self
.
cache_weight_mgr
.
elem_size_in_byte
/
1e6
/
\
self
.
cache_weight_mgr
.
_cuda_to_cpu_elapse
self
.
cache_weight_mgr
.
_cuda_to_cpu_elapse
return
0
return
0
\ No newline at end of file
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
View file @
9a9ef653
...
@@ -8,6 +8,7 @@ from colossalai.nn._ops._utils import dual_all_to_all
...
@@ -8,6 +8,7 @@ from colossalai.nn._ops._utils import dual_all_to_all
from
colossalai.tensor
import
ColoParameter
,
ShardSpec
,
ComputePattern
,
ProcessGroup
,
ColoTensorSpec
,
ColoTensor
from
colossalai.tensor
import
ColoParameter
,
ShardSpec
,
ComputePattern
,
ProcessGroup
,
ColoTensorSpec
,
ColoTensor
from
.cache_mgr
import
CachedParamMgr
,
EvictionStrategy
from
.cache_mgr
import
CachedParamMgr
,
EvictionStrategy
def
get_partition
(
embedding_dim
,
rank
,
world_size
)
->
Tuple
[
int
,
int
,
bool
]:
def
get_partition
(
embedding_dim
,
rank
,
world_size
)
->
Tuple
[
int
,
int
,
bool
]:
if
world_size
==
1
:
if
world_size
==
1
:
return
0
,
embedding_dim
,
True
return
0
,
embedding_dim
,
True
...
@@ -29,27 +30,25 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
...
@@ -29,27 +30,25 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
class
ParallelFreqAwareEmbeddingBag
(
FreqAwareEmbeddingBag
):
class
ParallelFreqAwareEmbeddingBag
(
FreqAwareEmbeddingBag
):
def
__init__
(
def
__init__
(
self
,
self
,
num_embeddings
,
num_embeddings
,
embedding_dim
,
embedding_dim
,
padding_idx
=
None
,
padding_idx
=
None
,
max_norm
=
None
,
max_norm
=
None
,
norm_type
=
2.
,
norm_type
=
2.
,
scale_grad_by_freq
=
False
,
scale_grad_by_freq
=
False
,
sparse
=
False
,
sparse
=
False
,
_weight
=
None
,
_weight
=
None
,
mode
=
'mean'
,
mode
=
'mean'
,
include_last_offset
=
False
,
include_last_offset
=
False
,
dtype
=
None
,
dtype
=
None
,
device
=
None
,
device
=
None
,
cuda_row_num
=
0
,
cuda_row_num
=
0
,
ids_freq_mapping
=
None
,
ids_freq_mapping
=
None
,
warmup_ratio
=
0.7
,
warmup_ratio
=
0.7
,
buffer_size
=
50_000
,
buffer_size
=
50_000
,
pin_weight
=
False
,
pin_weight
=
False
,
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
DATASET
):
evict_strategy
:
EvictionStrategy
=
EvictionStrategy
.
DATASET
):
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
world_size
=
torch
.
distributed
.
get_world_size
()
self
.
world_size
=
torch
.
distributed
.
get_world_size
()
...
@@ -60,7 +59,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
...
@@ -60,7 +59,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
super
(
ParallelFreqAwareEmbeddingBag
,
super
(
ParallelFreqAwareEmbeddingBag
,
self
).
__init__
(
num_embeddings
,
embedding_dim
,
padding_idx
,
max_norm
,
norm_type
,
scale_grad_by_freq
,
self
).
__init__
(
num_embeddings
,
embedding_dim
,
padding_idx
,
max_norm
,
norm_type
,
scale_grad_by_freq
,
sparse
,
_weight
,
mode
,
include_last_offset
,
dtype
,
device
,
cuda_row_num
,
ids_freq_mapping
,
sparse
,
_weight
,
mode
,
include_last_offset
,
dtype
,
device
,
cuda_row_num
,
ids_freq_mapping
,
warmup_ratio
,
buffer_size
,
pin_weight
,
evict_strategy
)
warmup_ratio
,
buffer_size
,
pin_weight
,
evict_strategy
)
def
_weight_alloc
(
self
,
dtype
,
device
):
def
_weight_alloc
(
self
,
dtype
,
device
):
weight
=
torch
.
empty
(
self
.
num_embeddings
,
self
.
embedding_dim_per_partition
,
device
=
device
,
dtype
=
dtype
)
weight
=
torch
.
empty
(
self
.
num_embeddings
,
self
.
embedding_dim_per_partition
,
device
=
device
,
dtype
=
dtype
)
...
@@ -77,8 +76,8 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
...
@@ -77,8 +76,8 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
indices
)
reorder_ids
=
self
.
cache_weight_mgr
.
prepare_ids
(
indices
)
output_shard
=
F
.
embedding_bag
(
reorder_ids
,
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
self
.
max_norm
,
output_shard
=
F
.
embedding_bag
(
reorder_ids
.
cuda
()
,
self
.
cache_weight_mgr
.
cuda_cached_weight
,
offsets
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
mode
,
self
.
sparse
,
per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
per_sample_weights
,
self
.
include_last_offset
,
self
.
padding_idx
)
if
shape_hook
is
not
None
:
if
shape_hook
is
not
None
:
...
...
tests/test_layers/test_cache_embedding.py
View file @
9a9ef653
...
@@ -83,15 +83,16 @@ def test_reorder_with_freq():
...
@@ -83,15 +83,16 @@ def test_reorder_with_freq():
chunkid
.
append
(
idx
//
chunk_size
)
chunkid
.
append
(
idx
//
chunk_size
)
offset_in_chunk
.
append
(
idx
%
chunk_size
)
offset_in_chunk
.
append
(
idx
%
chunk_size
)
chunkid
=
torch
.
tensor
(
chunkid
,
dtype
=
torch
.
long
,
device
=
torch
.
cuda
.
current_device
())
dev
=
torch
.
device
(
'cuda'
)
offset_in_chunk
=
torch
.
tensor
(
offset_in_chunk
,
dtype
=
torch
.
long
,
device
=
torch
.
cuda
.
current_device
())
chunkid
=
torch
.
tensor
(
chunkid
,
dtype
=
torch
.
long
,
device
=
dev
)
offset_in_chunk
=
torch
.
tensor
(
offset_in_chunk
,
dtype
=
torch
.
long
,
device
=
dev
)
weight
=
torch
.
rand
(
num_embed
,
2
)
weight
=
torch
.
rand
(
num_embed
,
2
)
mgr
=
CachedParamMgr
(
weight
,
num_chunk
)
mgr
=
CachedParamMgr
(
weight
,
num_chunk
,
use_cpu_caching
=
dev
.
type
==
'cpu'
)
mgr
.
reorder
(
idx_map
)
mgr
.
reorder
(
idx_map
)
indices
=
mgr
.
idx_map
.
index_select
(
0
,
torch
.
arange
(
num_embed
,
dtype
=
torch
.
long
,
device
=
torch
.
cuda
.
current_device
()
))
indices
=
mgr
.
idx_map
.
index_select
(
0
,
torch
.
arange
(
num_embed
,
dtype
=
torch
.
long
,
device
=
dev
))
mgr_chunk_id
=
torch
.
div
(
indices
,
chunk_size
,
rounding_mode
=
'floor'
)
mgr_chunk_id
=
torch
.
div
(
indices
,
chunk_size
,
rounding_mode
=
'floor'
)
mgr_offsets
=
torch
.
remainder
(
indices
,
chunk_size
)
mgr_offsets
=
torch
.
remainder
(
indices
,
chunk_size
)
assert
torch
.
allclose
(
chunkid
,
mgr_chunk_id
),
f
"chunk id:
{
chunkid
}
, mgr:
{
mgr_chunk_id
}
"
assert
torch
.
allclose
(
chunkid
,
mgr_chunk_id
),
f
"chunk id:
{
chunkid
}
, mgr:
{
mgr_chunk_id
}
"
...
@@ -280,6 +281,6 @@ def test_parallel_freq_aware_embed(world_size):
...
@@ -280,6 +281,6 @@ def test_parallel_freq_aware_embed(world_size):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
#
test_freq_aware_embed(True)
test_freq_aware_embed
(
True
)
# test_parallel_freq_aware_embed(2)
# test_parallel_freq_aware_embed(2)
test_lfu_strategy
(
False
)
#
test_lfu_strategy(False)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment