Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
363fc286
"...config/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "aacc5d761a4774047dd583a6607c44af368d08ab"
Unverified
Commit
363fc286
authored
Oct 12, 2022
by
Jiarui Fang
Committed by
GitHub
Oct 12, 2022
Browse files
[embeddings] more detailed timer (#1692)
parent
4973157a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
65 additions
and
51 deletions
+65
-51
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+65
-51
No files found.
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
View file @
363fc286
...
...
@@ -91,6 +91,7 @@ class CachedParamMgr(torch.nn.Module):
dtype
=
torch
.
long
).
fill_
(
sys
.
maxsize
),
persistent
=
False
)
self
.
_elapsed_dict
=
{}
self
.
_show_cache_miss
=
True
self
.
_reset_comm_stats
()
def
_reset_comm_stats
(
self
):
...
...
@@ -99,6 +100,9 @@ class CachedParamMgr(torch.nn.Module):
self
.
_cpu_to_cuda_numel
=
0
self
.
_cuda_to_cpu_numel
=
0
if
self
.
_show_cache_miss
:
self
.
_cache_miss
=
0
self
.
_total_cache
=
0
@
contextmanager
def
timer
(
self
,
name
):
...
...
@@ -268,6 +272,10 @@ class CachedParamMgr(torch.nn.Module):
self
.
inverted_cached_idx
.
index_fill_
(
0
,
row_ids
,
-
1
)
self
.
_cuda_available_row_num
+=
slots
.
numel
()
if
self
.
_show_cache_miss
:
self
.
_cache_miss
=
0
self
.
_total_cache
=
0
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
self
.
freq_cnter
.
fill_
(
sys
.
maxsize
)
assert
self
.
_cuda_available_row_num
==
self
.
cuda_row_num
...
...
@@ -275,14 +283,14 @@ class CachedParamMgr(torch.nn.Module):
assert
torch
.
all
(
self
.
cached_idx_map
==
-
1
).
item
()
def
print_comm_stats
(
self
):
if
self
.
_cuda_to_cpu_numel
>
0
and
"3_
2_2_
evict_out
_gpu_to_cpu_copy
"
in
self
.
_elapsed_dict
:
elapsed
=
self
.
_elapsed_dict
[
"3_
2_2_
evict_out
_gpu_to_cpu_copy
"
]
if
self
.
_cuda_to_cpu_numel
>
0
and
"3_evict_out"
in
self
.
_elapsed_dict
:
elapsed
=
self
.
_elapsed_dict
[
"3_evict_out"
]
print
(
f
"CUDA->CPU BWD
{
self
.
_cuda_to_cpu_numel
*
self
.
elem_size_in_byte
/
1e6
/
elapsed
}
MB/s
{
self
.
_cuda_to_cpu_numel
/
1e6
}
M elem"
)
print
(
f
'cuda_to_cpu_elapse
{
elapsed
}
sec'
)
if
self
.
_cpu_to_cuda_numel
>
0
and
"
3_4_2
_evict_in
_gpu_to_cpu_copy
"
in
self
.
_elapsed_dict
:
elapsed
=
self
.
_elapsed_dict
[
"
3_4_2
_evict_in
_gpu_to_cpu_copy
"
]
if
self
.
_cpu_to_cuda_numel
>
0
and
"
5
_evict_in"
in
self
.
_elapsed_dict
:
elapsed
=
self
.
_elapsed_dict
[
"
5
_evict_in"
]
print
(
f
"CPU->CUDA BWD
{
self
.
_cpu_to_cuda_numel
*
self
.
elem_size_in_byte
/
1e6
/
elapsed
}
MB/s
{
self
.
_cpu_to_cuda_numel
/
1e6
}
M elem"
)
...
...
@@ -291,6 +299,8 @@ class CachedParamMgr(torch.nn.Module):
for
k
,
v
in
self
.
_elapsed_dict
.
items
():
print
(
f
'
{
k
}
:
{
v
}
'
)
print
(
f
'cache miss ratio
{
self
.
_cache_miss
/
self
.
_total_cache
}
'
)
@
torch
.
no_grad
()
def
_id_to_cached_cuda_id
(
self
,
ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
...
...
@@ -315,41 +325,45 @@ class CachedParamMgr(torch.nn.Module):
torch.Tensor: indices on the cuda_cached_weight.
"""
torch
.
cuda
.
synchronize
()
with
self
.
timer
(
"1_unique_indices"
)
as
timer
:
with
record_function
(
"(cache) get unique indices"
):
cpu_row_idxs
,
repeat_times
=
torch
.
unique
(
self
.
idx_map
.
index_select
(
0
,
ids
),
return_counts
=
True
)
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
f
"You move
{
len
(
cpu_row_idxs
)
}
embedding rows from CPU to CUDA. "
\
f
"It is larger than the capacity of the cache, which at most contains
{
self
.
cuda_row_num
}
rows, "
\
f
"Please increase cuda_row_num or decrease the training batch size."
self
.
evict_backlist
=
cpu_row_idxs
torch
.
cuda
.
synchronize
()
# O(cache ratio)
with
self
.
timer
(
"2_cpu_row_idx"
)
as
timer
:
with
record_function
(
"(cache) get cpu row idxs"
):
comm_cpu_row_idxs
=
cpu_row_idxs
[
torch
.
isin
(
cpu_row_idxs
,
self
.
cached_idx_map
,
invert
=
True
)]
self
.
num_hits_history
.
append
(
len
(
cpu_row_idxs
)
-
len
(
comm_cpu_row_idxs
))
self
.
num_miss_history
.
append
(
len
(
comm_cpu_row_idxs
))
self
.
num_write_back_history
.
append
(
0
)
# move sure the cuda rows will not be evicted!
with
self
.
timer
(
"3_prepare_rows_on_cuda"
)
as
timer
:
with
self
.
timer
(
"cache_op"
)
as
gtimer
:
# identify cpu rows to cache
with
self
.
timer
(
"1_identify_cpu_row_idxs"
)
as
timer
:
with
record_function
(
"(cache) get unique indices"
):
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
cpu_row_idxs
,
repeat_times
=
torch
.
unique
(
ids
,
return_counts
=
True
)
else
:
cpu_row_idxs
,
repeat_times
=
torch
.
unique
(
self
.
idx_map
.
index_select
(
0
,
ids
),
return_counts
=
True
)
assert
len
(
cpu_row_idxs
)
<=
self
.
cuda_row_num
,
\
f
"You move
{
len
(
cpu_row_idxs
)
}
embedding rows from CPU to CUDA. "
\
f
"It is larger than the capacity of the cache, which at most contains
{
self
.
cuda_row_num
}
rows, "
\
f
"Please increase cuda_row_num or decrease the training batch size."
self
.
evict_backlist
=
cpu_row_idxs
tmp
=
torch
.
isin
(
cpu_row_idxs
,
self
.
cached_idx_map
,
invert
=
True
)
comm_cpu_row_idxs
=
cpu_row_idxs
[
tmp
]
if
self
.
_show_cache_miss
:
self
.
_cache_miss
+=
torch
.
sum
(
repeat_times
[
tmp
])
self
.
_total_cache
+=
ids
.
numel
()
self
.
num_hits_history
.
append
(
len
(
cpu_row_idxs
)
-
len
(
comm_cpu_row_idxs
))
self
.
num_miss_history
.
append
(
len
(
comm_cpu_row_idxs
))
self
.
num_write_back_history
.
append
(
0
)
# move sure the cuda rows will not be evicted!
with
record_function
(
"(cache) prepare_rows_on_cuda"
):
self
.
_prepare_rows_on_cuda
(
comm_cpu_row_idxs
)
self
.
evict_backlist
=
torch
.
tensor
([],
device
=
cpu_row_idxs
.
device
,
dtype
=
cpu_row_idxs
.
dtype
)
self
.
evict_backlist
=
torch
.
tensor
([],
device
=
cpu_row_idxs
.
device
,
dtype
=
cpu_row_idxs
.
dtype
)
with
self
.
timer
(
"
4_cpu_to_gpu_row_idxs
"
)
as
timer
:
with
record_function
(
"
(cache) embed cpu rows idx -> cache gpu row idxs
"
):
gpu_row_idxs
=
self
.
_id_to_cached_cuda_id
(
ids
)
with
self
.
timer
(
"
6_update_cache
"
)
as
timer
:
with
record_function
(
"
6_update_cache
"
):
gpu_row_idxs
=
self
.
_id_to_cached_cuda_id
(
ids
)
# update for LFU.
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
unique_gpu_row_idxs
=
self
.
inverted_cached_idx
[
cpu_row_idxs
]
self
.
freq_cnter
.
scatter_add_
(
0
,
unique_gpu_row_idxs
,
repeat_times
)
# update for LFU.
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
unique_gpu_row_idxs
=
self
.
inverted_cached_idx
[
cpu_row_idxs
]
self
.
freq_cnter
.
scatter_add_
(
0
,
unique_gpu_row_idxs
,
repeat_times
)
return
gpu_row_idxs
...
...
@@ -377,8 +391,7 @@ class CachedParamMgr(torch.nn.Module):
raise
NotImplemented
if
evict_num
>
0
:
torch
.
cuda
.
synchronize
()
with
self
.
timer
(
"3_1_evict_prepare"
)
as
timer
:
with
self
.
timer
(
"2_identify_cuda_row_idxs"
)
as
timer
:
mask_cpu_row_idx
=
torch
.
isin
(
self
.
cached_idx_map
,
self
.
evict_backlist
)
invalid_idxs
=
torch
.
nonzero
(
mask_cpu_row_idx
).
squeeze
(
1
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
DATASET
:
...
...
@@ -388,7 +401,7 @@ class CachedParamMgr(torch.nn.Module):
backup_idxs
=
self
.
cached_idx_map
[
mask_cpu_row_idx
].
clone
()
self
.
cached_idx_map
.
index_fill_
(
0
,
invalid_idxs
,
-
2
)
with
self
.
timer
(
"
3_1
_1_find_evict_gpu_idxs
_elapsed
"
)
as
timer
:
with
self
.
timer
(
"
2
_1_find_evict_gpu_idxs"
)
as
timer
:
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
# move evict out rows to cpu
...
...
@@ -401,11 +414,11 @@ class CachedParamMgr(torch.nn.Module):
self
.
cached_idx_map
.
index_copy_
(
0
,
invalid_idxs
,
backup_idxs
)
elif
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
with
self
.
timer
(
"
3
_1_
0_
backup_freqs"
)
as
timer
:
with
self
.
timer
(
"
2
_1_backup_freqs"
)
as
timer
:
backup_freqs
=
self
.
freq_cnter
[
invalid_idxs
].
clone
()
self
.
freq_cnter
.
index_fill_
(
0
,
invalid_idxs
,
sys
.
maxsize
)
with
self
.
timer
(
"
3_1_1
_find_evict_gpu_idxs
_elapsed
"
)
as
timer
:
with
self
.
timer
(
"
2_2
_find_evict_gpu_idxs"
)
as
timer
:
evict_gpu_row_idxs
=
self
.
_find_evict_gpu_idxs
(
evict_num
)
if
self
.
_async_copy
:
...
...
@@ -414,12 +427,13 @@ class CachedParamMgr(torch.nn.Module):
evict_out_rows_cpu
=
torch
.
empty_like
(
evict_out_rows_gpu
,
device
=
'cpu'
,
pin_memory
=
True
)
with
torch
.
cuda
.
stream
(
None
):
evict_out_rows_cpu
.
copy_
(
evict_out_rows_gpu
,
non_blocking
=
True
)
with
self
.
timer
(
"3_1_2_find_evict_index_copy"
)
as
timer
:
with
self
.
timer
(
"2_3_revert_freqs"
)
as
timer
:
self
.
freq_cnter
.
index_copy_
(
0
,
invalid_idxs
,
backup_freqs
)
evict_info
=
self
.
cached_idx_map
[
evict_gpu_row_idxs
]
with
self
.
timer
(
"3_
2_
evict_out
_elapse
"
)
as
timer
:
with
self
.
timer
(
"3_evict_out"
)
as
timer
:
if
self
.
buffer_size
>
0
:
self
.
limit_buff_index_copyer
.
index_copy
(
0
,
src_index
=
evict_gpu_row_idxs
,
...
...
@@ -432,13 +446,13 @@ class CachedParamMgr(torch.nn.Module):
if
self
.
_async_copy
:
_wait_for_data
(
evict_out_rows_cpu
,
None
)
else
:
with
self
.
timer
(
"3_
2_
1_evict_out_index_select"
)
as
timer
:
with
self
.
timer
(
"3_1_evict_out_index_select"
)
as
timer
:
evict_out_rows_cpu
=
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_select
(
0
,
evict_gpu_row_idxs
)
with
self
.
timer
(
"3_2_
2_
evict_out_gpu_to_cpu_copy"
)
as
timer
:
with
self
.
timer
(
"3_2_evict_out_gpu_to_cpu_copy"
)
as
timer
:
evict_out_rows_cpu
=
evict_out_rows_cpu
.
cpu
()
with
self
.
timer
(
"3_2_
2_
evict_out_
index_select
"
)
as
timer
:
with
self
.
timer
(
"3_2_evict_out_
cpu_copy
"
)
as
timer
:
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_copy_
(
0
,
evict_info
.
cpu
(),
evict_out_rows_cpu
)
self
.
cached_idx_map
.
index_fill_
(
0
,
evict_gpu_row_idxs
,
-
1
)
...
...
@@ -447,15 +461,15 @@ class CachedParamMgr(torch.nn.Module):
self
.
_cuda_available_row_num
+=
evict_num
weight_size
=
evict_gpu_row_idxs
.
numel
()
*
self
.
embedding_dim
self
.
_cuda_to_cpu_numel
+=
weight_size
self
.
_cuda_to_cpu_numel
+=
weight_size
# print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
# slots of cuda weight to evict in
with
self
.
timer
(
"
3_3_non_zero
"
)
as
timer
:
with
self
.
timer
(
"
4_identify_cuda_slot
"
)
as
timer
:
slots
=
torch
.
nonzero
(
self
.
cached_idx_map
==
-
1
).
squeeze
(
1
)[:
cpu_row_idxs
.
numel
()]
# TODO wait for optimize
with
self
.
timer
(
"
3_4
_evict_in
_elapse
"
)
as
timer
:
with
self
.
timer
(
"
5
_evict_in"
)
as
timer
:
# Here also allocate extra memory on CUDA. #cpu_row_idxs
if
self
.
buffer_size
>
0
:
self
.
limit_buff_index_copyer
.
index_copy
(
0
,
...
...
@@ -467,20 +481,20 @@ class CachedParamMgr(torch.nn.Module):
if
self
.
_async_copy
:
_wait_for_data
(
evict_in_rows_gpu
,
self
.
_memcpy_stream
)
else
:
with
self
.
timer
(
"
3_4
_1_evict_in_index_select"
)
as
timer
:
with
self
.
timer
(
"
5
_1_evict_in_index_select"
)
as
timer
:
# narrow index select to a subset of self.weight
# tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1)
# evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu())
evict_in_rows_gpu
=
self
.
weight
.
view
(
self
.
num_embeddings
,
-
1
).
index_select
(
0
,
cpu_row_idxs_copy
).
pin_memory
()
with
self
.
timer
(
"
3_4
_2_evict_in_gpu_to_cpu_copy"
)
as
timer
:
with
self
.
timer
(
"
5
_2_evict_in_gpu_to_cpu_copy"
)
as
timer
:
evict_in_rows_gpu
=
evict_in_rows_gpu
.
cuda
()
with
self
.
timer
(
"
3_4
_3_evict_in_index_copy"
)
as
timer
:
with
self
.
timer
(
"
5
_3_evict_in_index_copy"
)
as
timer
:
self
.
cuda_cached_weight
.
view
(
self
.
cuda_row_num
,
-
1
).
index_copy_
(
0
,
slots
,
evict_in_rows_gpu
)
with
self
.
timer
(
"
3_5_evict_in_elapse_final
"
)
as
timer
:
with
self
.
timer
(
"
6_update_cache
"
)
as
timer
:
self
.
cached_idx_map
[
slots
]
=
cpu_row_idxs
self
.
inverted_cached_idx
.
index_copy_
(
0
,
cpu_row_idxs
,
slots
)
if
self
.
_evict_strategy
==
EvictionStrategy
.
LFU
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment