OpenDAS / ColossalAI · Commits

Unverified commit 504419d2, authored Aug 09, 2022 by Jiarui Fang, committed by GitHub on Aug 09, 2022
[FAW] add cache manager for the cached embedding (#1419)
parent 44fd3c83

Showing 7 changed files with 514 additions and 0 deletions (+514 −0)
colossalai/nn/_ops/cache_embedding/__init__.py        +4   −0
colossalai/nn/_ops/cache_embedding/base_embedding.py  +36  −0
colossalai/nn/_ops/cache_embedding/cache_mgr.py       +348 −0
colossalai/nn/_ops/cache_embedding/copyer.py          +48  −0
requirements/requirements-test.txt                    +1   −0
requirements/requirements.txt                         +1   −0
tests/test_tensor/ops/test_cache_embedding.py         +76  −0
colossalai/nn/_ops/cache_embedding/__init__.py 0 → 100644

from .cache_mgr import CachedParamMgr
from .copyer import LimitBuffIndexCopyer

__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer']
\ No newline at end of file
colossalai/nn/_ops/cache_embedding/base_embedding.py 0 → 100644

import abc

import torch.nn as nn


class BaseEmbeddingBag(abc.ABC, nn.Module):

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.,
        scale_grad_by_freq=False,
        sparse=False,
        mode='mean',
        include_last_offset=False,
    ):
        super(BaseEmbeddingBag, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        # Specific to embedding bag
        self.mode = mode
        self.include_last_offset = include_last_offset
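BaseEmbeddingBag only stores hyperparameters; a concrete subclass supplies the weight and the forward pass. As a rough illustration, a minimal hypothetical subclass (not part of this commit) could forward to torch.nn.functional.embedding_bag:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveEmbeddingBag(BaseEmbeddingBag):
    # hypothetical example subclass; a cached variant would manage
    # the weight through CachedParamMgr instead of a dense Parameter
    def __init__(self, num_embeddings, embedding_dim, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))

    def forward(self, indices, offsets=None):
        return F.embedding_bag(indices,
                               self.weight,
                               offsets=offsets,
                               max_norm=self.max_norm,
                               norm_type=self.norm_type,
                               scale_grad_by_freq=self.scale_grad_by_freq,
                               mode=self.mode,
                               sparse=self.sparse,
                               include_last_offset=self.include_last_offset,
                               padding_idx=self.padding_idx)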
colossalai/nn/_ops/cache_embedding/cache_mgr.py 0 → 100644

import numpy as np
import torch
from torch.profiler import record_function
from typing import List, Optional
from contexttimer import Timer

from .copyer import LimitBuffIndexCopyer


class CachedParamMgr(torch.nn.Module):
    """
    Manage Embedding Weights in Cache on CPU and CUDA memory.
    CPU maintains the entire original weight.
    CUDA maintains a fraction of the weights used in the upcoming computation.
    During training, GPU needs to transmit rows between CPU and GPU.
    """

    def __init__(self, weight: torch.Tensor, cuda_row_num: int = 0, buffer_size: int = 50_000) -> None:
        super(CachedParamMgr, self).__init__()
        self.buffer_size = buffer_size
        self.num_embeddings, self.embedding_dim = weight.shape
        self.cuda_row_num = cuda_row_num
        self._cuda_available_row_num = self.cuda_row_num
        self.elem_size_in_byte = weight.element_size()

        self.cuda_cached_weight = torch.nn.Parameter(
            torch.zeros(self.cuda_row_num,
                        self.embedding_dim,
                        device=torch.cuda.current_device(),
                        dtype=weight.dtype))

        if weight.device.type == 'cuda':
            weight = weight.cpu()

        # pin memory cpu for higher CPU-GPU copy bandwidth
        self.cpu_weight = weight.contiguous().pin_memory()

        # map original id to new id with respect to frequency
        # id -> cpu_row_idx
        self.register_buffer(
            "idx_map",
            torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()),
            persistent=False,
        )

        # cached_idx_map: gpu_row_idx -> cpu_row_idx
        self.register_buffer("cached_idx_map",
                             torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
                                         dtype=torch.long).fill_(-1),
                             persistent=False)

        # cpu_row_id -> gpu_row_idx.
        # gpu_row_idx as -1 means cpu_row_id not in CUDA.
        self.register_buffer("inverted_cached_idx",
                             torch.zeros(self.num_embeddings, device=torch.cuda.current_device(),
                                         dtype=torch.long).fill_(-1),
                             persistent=False)

        self.evict_backlist = torch.tensor([], device=torch.cuda.current_device())

        # index copy buffer size should be less than 10% of cuda weight.
        if self.buffer_size > 0:
            self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size)

        self.num_hits_history = []
        self.num_miss_history = []
        self.num_write_back_history = []
        self.input_id_percent_in_load_chunk = []
        self._reset_comm_stats()

    def cpu_weight_data(self, chunk_id: int) -> torch.Tensor:
        """
        access a chunk of CPU weight.

        Args:
            chunk_id (int): chunk id

        Returns:
            torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D.
        """
        return self.cpu_weight.data.view(-1).narrow(0,
                                                    int(chunk_id) * self.embedding_dim,
                                                    self.embedding_dim).view(1, self.embedding_dim)

    @property
    def cuda_available_chunk_num(self):
        return self._cuda_available_row_num

    @torch.no_grad()
    def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
        """reorder the cpu_weight according to ids' frequency in the dataset before training.
        Also build the IndexMappingTable, aka index_mapping_table.
        Execute only once before training.

        Args:
            ids_freq_mapping (List[int]): a list (or numpy array) whose index is the id and whose value is its frequency. If None, no reorder is performed.
            warmup_ratio (float): the fraction of the cuda cache to preload with the highest-frequency chunks
        """
        if ids_freq_mapping is not None:
            tmp_idx = torch.argsort(torch.from_numpy(ids_freq_mapping).cuda(), descending=True)
            sorted_idx = torch.argsort(tmp_idx)
            self.idx_map.data.copy_(sorted_idx)

        # TODO() The following code will allocate extra CUDA memory: preload_row_num * chunks.
        # As cuda_cached_weight is very big, you may not have that much available memory!
        # Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
        preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
        if preload_row_num > 0:
            with Timer() as timer:
                # extract chunks from cpu weight
                preload_row_ids = torch.arange(preload_row_num)
                preload_slot_ids = preload_row_ids.cuda()

                if self.buffer_size > 0:
                    self.limit_buff_index_copyer.index_copy(0,
                                                            src_index=preload_row_ids,
                                                            tgt_index=preload_slot_ids,
                                                            src=self.cpu_weight.view(self.num_embeddings, -1),
                                                            tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
                else:
                    preload_chunks = self.cpu_weight.view(self.num_embeddings,
                                                          -1).index_select(0, preload_row_ids).cuda()
                    self.cuda_cached_weight.view(self.cuda_row_num,
                                                 -1).index_copy_(0, preload_slot_ids, preload_chunks)

                # update auxiliary info
                slot_offsets = preload_slot_ids
                self.cached_idx_map[preload_slot_ids] = preload_slot_ids
                self.inverted_cached_idx[preload_slot_ids] = slot_offsets
                self._cuda_available_row_num -= preload_row_num
            print(f'Cache warmup finished, cost {timer.elapsed} sec.')

    def flush(self):
        """flush all CUDA chunks to CPU.
        The function is usually called after training finishes.
        """
        slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
        chunk_ids = self.cached_idx_map[slots]
        chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
        self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks)
        self.cached_idx_map.index_fill_(0, slots, -1)
        self.inverted_cached_idx.index_fill_(0, chunk_ids, -1)
        self._cuda_available_row_num += slots.numel()

        assert self._cuda_available_row_num == self.cuda_row_num
        assert torch.all(self.inverted_cached_idx == -1).item()
        assert torch.all(self.cached_idx_map == -1).item()

    def print_comm_stats(self):
        if self._cuda_to_cpu_numel > 0:
            print(
                f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / self._cuda_to_cpu_elapse} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
            )
        if self._cpu_to_cuda_numel > 0:
            print(
                f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / self._cpu_to_cuda_elpase} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
            )

    @torch.no_grad()
    def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor:
        """
        convert ids to indices in self.cuda_cached_weight.
        Implemented with parallel operations on GPU.

        Args:
            ids (torch.Tensor): ids from the dataset

        Returns:
            torch.Tensor: contains indices in self.cuda_cached_weight
        """
        ids = self.idx_map.index_select(0, ids.view(-1))
        ret = self.inverted_cached_idx.index_select(0, ids)
        return ret

    @torch.no_grad()
    def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:
        """
        move the cpu embedding rows w.r.t. ids into CUDA memory

        Args:
            ids (torch.Tensor): the ids to be computed

        Returns:
            torch.Tensor: indices on the cuda_cached_weight.
        """
        with record_function("(zhg) get unique indices"):
            cpu_row_idxs = torch.unique(self.idx_map.index_select(0, ids))

            assert len(cpu_row_idxs) <= self.cuda_row_num, \
                f"the input indices pull {len(cpu_row_idxs)} chunks, " \
                f"which is larger than the presented {self.cuda_row_num}, " \
                f"please increase cuda_row_num or shrink batch size"
            self.evict_backlist = cpu_row_idxs

        with record_function("(zhg) get cpu chunk indices"):
            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]

        self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
        self.num_miss_history.append(len(comm_cpu_row_idxs))
        self.num_write_back_history.append(0)

        # make sure the cuda chunks will not be evicted!
        with record_function("(zhg) cache update"):
            self._prepare_rows_on_cuda(comm_cpu_row_idxs)

        self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)

        # new ids chunk_offset + offset_in_chunk
        with record_function("(zhg) embed idx -> cache chunk id"):
            gpu_row_idxs = self._id_to_cached_cuda_id(ids)
        return gpu_row_idxs

    def _reset_comm_stats(self):
        self._cpu_to_cuda_numel = 0
        self._cpu_to_cuda_elpase = 0
        self._cuda_to_cpu_elapse = 0
        self._cuda_to_cpu_numel = 0

    def _chunk_in_cuda(self, chunk_id: int) -> bool:
        return self.inverted_cached_idx[chunk_id] != -1

    @torch.no_grad()
    def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
        """prepare rows in cpu_row_idxs on CUDA memory

        Args:
            cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA
        """
        evict_num = cpu_row_idxs.numel() - self.cuda_available_chunk_num
        if evict_num > 0:
            with Timer() as timer:
                mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
                backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
                invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)

                self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
                evict_gpu_row_idxs = torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
                self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)

                evict_info = self.cached_idx_map[evict_gpu_row_idxs]

                if self.buffer_size > 0:
                    self.limit_buff_index_copyer.index_copy(0,
                                                            src_index=evict_gpu_row_idxs,
                                                            tgt_index=evict_info.cpu(),
                                                            src=self.cuda_cached_weight.view(self.cuda_row_num, -1),
                                                            tgt=self.cpu_weight.view(self.num_embeddings, -1))
                else:
                    # allocate tmp memory on CPU and copy rows on CUDA to CPU.
                    rows = self.cuda_cached_weight.view(self.cuda_row_num,
                                                        -1).index_select(0, evict_gpu_row_idxs).cpu()
                    self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)

                self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
                self.inverted_cached_idx.index_fill_(0, evict_info, -1)
                self._cuda_available_row_num += evict_num

            weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim
            self._cuda_to_cpu_elapse += timer.elapsed
            self._cuda_to_cpu_numel += weight_size
            # print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")

        with Timer() as timer:
            slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()]
            # Here also allocate extra memory on CUDA. #cpu_row_idxs
            if self.buffer_size > 0:
                self.limit_buff_index_copyer.index_copy(0,
                                                        src_index=cpu_row_idxs.cpu(),
                                                        tgt_index=slots,
                                                        src=self.cpu_weight.view(self.num_embeddings, -1),
                                                        tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
            else:
                rows = self.cpu_weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
                self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows)
            slot_offsets = slots
            self.cached_idx_map[slots] = cpu_row_idxs
            self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
            self._cuda_available_row_num -= cpu_row_idxs.numel()
        self._cpu_to_cuda_elpase += timer.elapsed
        weight_size = cpu_row_idxs.numel() * self.embedding_dim
        self._cpu_to_cuda_numel += weight_size
        # print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")

    def _evict(self) -> int:
        """
        evict one chunk from cuda to cpu.

        Returns:
            (int) : the slot id that was evicted.
        """
        mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1)
        buf = self.cached_idx_map[mask].clone()
        idx = torch.nonzero(mask).squeeze(1)
        self.cached_idx_map.index_fill_(0, idx, -1)
        # pick the cached row with the largest cpu row id (lowest frequency after reorder)
        max_row, max_cpu_row_idx = torch.max(self.cached_idx_map, dim=0)
        max_gpu_row_idx = self.cached_idx_map[max_cpu_row_idx]

        if max_gpu_row_idx == -1:
            raise RuntimeError("Can not evict a row")

        max_gpu_row_idx = max_gpu_row_idx.item()
        max_offset = self.inverted_cached_idx[max_gpu_row_idx]
        # recover
        self.cached_idx_map.index_copy_(0, idx, buf)

        with Timer() as timer:
            cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim,
                                       self.embedding_dim).view(1, self.embedding_dim)
            self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor)

        # update inverted_cached_idx; the chosen row is evicted from cuda
        self.cached_idx_map[max_cpu_row_idx] = -1

        self.inverted_cached_idx[max_gpu_row_idx] = -1

        self._cuda_available_row_num += 1

        self._cuda_to_cpu_numel += self.embedding_dim
        self._cuda_to_cpu_elapse += timer.elapsed
        # self.num_write_back_history[-1] += 1
        return max_cpu_row_idx

    def _find_free_cuda_row(self) -> int:
        if self._cuda_available_row_num == 0:
            return -1
        candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
        return candidates[0].item()

    @torch.no_grad()
    def _admit(self, row_id: int):
        """
        move in row_id to CUDA

        Args:
            row_id (int): the id of the row to be moved in
        """
        # find a free slot in the partial cuda weight
        slot_id = self._find_free_cuda_row()

        if slot_id == -1:
            # evict one row
            slot_id = self._evict()
        slot_offset = slot_id
        # copy payload from cpu to cuda
        with Timer() as timer:
            cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim,
                                       self.embedding_dim).view(1, self.embedding_dim)
            cuda_tensor.data.copy_(self.cpu_weight_data(row_id))

        # update the inverted_cached_idx
        self.cached_idx_map[slot_id] = row_id

        self.inverted_cached_idx[row_id] = slot_offset

        self._cuda_available_row_num -= 1

        self._cpu_to_cuda_numel += self.embedding_dim
        self._cpu_to_cuda_elpase += timer.elapsed
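For orientation, here is a minimal usage sketch of CachedParamMgr based on the docstrings above (not part of this commit; assumes a CUDA device and illustrative sizes):

import torch
from colossalai.nn._ops.cache_embedding import CachedParamMgr

weight = torch.rand(1000, 16)                    # full embedding table, kept on CPU
mgr = CachedParamMgr(weight, cuda_row_num=64)    # cache up to 64 rows on CUDA

# optionally reorder rows by dataset frequency and warm up the cache:
# mgr.reorder(ids_freq_mapping)

ids = torch.randint(0, 1000, (32,), device=torch.cuda.current_device())
cuda_idxs = mgr.prepare_ids(ids)                 # pull any missing rows into CUDA
rows = mgr.cuda_cached_weight.view(64, -1).index_select(0, cuda_idxs)

mgr.flush()                                      # write every cached row back to CPU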
colossalai/nn/_ops/cache_embedding/copyer.py 0 → 100644

import torch
from torch import LongTensor


class LimitBuffIndexCopyer(object):
    """LimitBuffIndexCopyer
    Index Copy using a limited temp buffer on CUDA.

    Args:
        size (int): buffer size
    """

    def __init__(self, size: int) -> None:
        self._buff_size = size

    @torch.no_grad()
    def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor,
                   tgt: torch.Tensor):
        """copy
        src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index]
        The valid rows in src are contiguous, while those in tgt are scattered.

        Args:
            dim (int): dimension along which to index
            src_index (LongTensor): indices of src tensor to select from
            tgt_index (LongTensor): indices of tgt tensor to copy to
            src (torch.Tensor): the tensor containing values to copy
            tgt (torch.Tensor): the tensor to be copied to
        """
        # tgt.index_copy_(dim, index, src)
        assert dim == 0, "only support index_copy on dim 0"
        assert tgt.dim() == 2
        assert src.dim() == 2
        tgt_device = tgt.device
        src_device = src.device

        assert src_index.numel() == tgt_index.numel()
        dim_size = src_index.numel()
        src_index = src_index.to(src_device)
        for begin_pos in range(0, dim_size, self._buff_size):
            cur_len = min(self._buff_size, dim_size - begin_pos)
            src_idx_piece = src_index.narrow(0, begin_pos, cur_len)
            if src_device.type == 'cpu' and tgt_device.type == 'cuda':
                cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory()
                tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device)
                tmp_buffer.copy_(cpu_tmp_buffer)
            else:
                tmp_buffer = src.index_select(dim, src_idx_piece).to(tgt_device)
            tgt_idx_piece = tgt_index.narrow(0, begin_pos, cur_len)
            tgt.index_copy_(dim, tgt_idx_piece, tmp_buffer)
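A small sketch (hypothetical, not part of this commit) of a chunked CPU-to-CUDA copy with a 4-row staging buffer:

import torch
from colossalai.nn._ops.cache_embedding import LimitBuffIndexCopyer

src = torch.rand(100, 8)                        # contiguous CPU source rows
tgt = torch.zeros(10, 8, device='cuda')         # scattered CUDA destination slots
copyer = LimitBuffIndexCopyer(4)                # stage at most 4 rows per hop

src_index = torch.arange(10)                    # first 10 rows of src
tgt_index = torch.randperm(10, device='cuda')   # shuffled slots in tgt
copyer.index_copy(0, src_index=src_index, tgt_index=tgt_index, src=src, tgt=tgt)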
requirements/requirements-test.txt

@@ -5,3 +5,4 @@ timm
 titans
 torchaudio
 torchrec
+contexttimer
requirements/requirements.txt

@@ -7,3 +7,4 @@ pre-commit
 rich
 click
 fabric
+contexttimer
\ No newline at end of file
tests/test_tensor/ops/test_cache_embedding.py 0 → 100644

import pytest
from functools import partial

import torch
import torch.multiprocessing as mp
import numpy as np

from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.nn._ops.cache_embedding import CachedParamMgr

NUM_EMBED, EMBED_DIM = 100, 8
BATCH_SIZE = 8


def test_cachemgr():
    model = torch.nn.EmbeddingBag(10000, 128)
    # 10 chunks, 5 in cuda
    mgr = CachedParamMgr(model.weight, 5)
    assert mgr.cuda_row_num == 5

    mgr._admit(1)
    assert not mgr._chunk_in_cuda(2)
    assert mgr._chunk_in_cuda(1)

    # print(mgr.cached_chunk_table)
    mgr._admit(8)

    # now 3 chunks are available
    assert mgr.cuda_available_chunk_num == 3

    mgr._evict()
    assert mgr.cuda_available_chunk_num == 4

    mgr._prepare_rows_on_cuda(torch.tensor([9, 6, 5], dtype=torch.long, device=0))
    mgr._prepare_rows_on_cuda(torch.tensor([3, 4, 5], dtype=torch.long, device=0))
    # print(mgr.cached_chunk_table)
    # mgr.print_comm_stats()

    mgr.flush()
    assert mgr.cuda_available_chunk_num == 5


def test_reorder_with_freq():
    num_embed = 100
    chunk_size = 1
    num_chunk = 5

    idx_map = np.random.randint(10000, size=(num_embed,))
    sorted_idx = np.flipud(np.argsort(idx_map)).tolist()
    chunkid, offset_in_chunk = [], []
    for i in range(num_embed):
        idx = sorted_idx.index(i)
        chunkid.append(idx // chunk_size)
        offset_in_chunk.append(idx % chunk_size)

    chunkid = torch.tensor(chunkid, dtype=torch.long, device=torch.cuda.current_device())
    offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=torch.cuda.current_device())

    weight = torch.rand(num_embed, 2)
    mgr = CachedParamMgr(weight, num_chunk)

    mgr.reorder(idx_map)

    indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long,
                                                       device=torch.cuda.current_device()))
    mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor')
    mgr_offsets = torch.remainder(indices, chunk_size)
    assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
    assert torch.allclose(offset_in_chunk, mgr_offsets), \
        f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"


if __name__ == '__main__':
    # test_freq_aware_embed()
    # test_chunkmgr_admit()
    pass