"docs/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "80eba05b0abc0ce24f02254cbe2c7b8f9ff5d688"
Unverified commit ba61109b, authored by Jiarui Fang, committed by GitHub
Browse files

[FAW] remove code related to chunk (#1501)

parent d5085bb3
...@@ -56,7 +56,6 @@ class CachedParamMgr(torch.nn.Module): ...@@ -56,7 +56,6 @@ class CachedParamMgr(torch.nn.Module):
self.num_hits_history = [] self.num_hits_history = []
self.num_miss_history = [] self.num_miss_history = []
self.num_write_back_history = [] self.num_write_back_history = []
self.input_id_percent_in_load_chunk = []
self._reset_comm_stats() self._reset_comm_stats()
self._evict_strategy = evict_strategy self._evict_strategy = evict_strategy
...@@ -156,23 +155,23 @@ class CachedParamMgr(torch.nn.Module): ...@@ -156,23 +155,23 @@ class CachedParamMgr(torch.nn.Module):
# self.cuda_cached_weight = self.weight # self.cuda_cached_weight = self.weight
raise NotImplementedError() raise NotImplementedError()
def cpu_weight_data(self, chunk_id: int) -> torch.Tensor: def cpu_weight_data(self, row_idx: int) -> torch.Tensor:
""" """
access a chunk of CPU weight. access a row of CPU weight.
Args: Args:
chunk_id (int): chunk id row_idx (int): the idx of rows
Returns: Returns:
torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D. torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D.
""" """
return self.weight.data.view(-1).narrow(0, return self.weight.data.view(-1).narrow(0,
int(chunk_id) * self.embedding_dim, int(row_idx) * self.embedding_dim,
self.embedding_dim).view(1, self.embedding_dim) self.embedding_dim).view(1, self.embedding_dim)
@property @property
def cuda_available_chunk_num(self): def cuda_available_row_num(self):
return self._cuda_available_row_num return self._cuda_available_row_num
@torch.no_grad() @torch.no_grad()
...@@ -202,7 +201,7 @@ class CachedParamMgr(torch.nn.Module): ...@@ -202,7 +201,7 @@ class CachedParamMgr(torch.nn.Module):
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings) preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
if preload_row_num > 0: if preload_row_num > 0:
with Timer() as timer: with Timer() as timer:
# extract chunks from cpu weight # extract rows from cpu weight
preload_row_ids = torch.arange(preload_row_num) preload_row_ids = torch.arange(preload_row_num)
preload_slot_ids = preload_row_ids.cuda() preload_slot_ids = preload_row_ids.cuda()
...@@ -213,8 +212,8 @@ class CachedParamMgr(torch.nn.Module): ...@@ -213,8 +212,8 @@ class CachedParamMgr(torch.nn.Module):
src=self.weight.view(self.num_embeddings, -1), src=self.weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
else: else:
preload_chunks = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda() preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks) self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows)
# update auxiliary info # update auxiliary info
slot_offsets = preload_slot_ids slot_offsets = preload_slot_ids
...@@ -224,15 +223,15 @@ class CachedParamMgr(torch.nn.Module): ...@@ -224,15 +223,15 @@ class CachedParamMgr(torch.nn.Module):
print(f'Cache warmup finished cost {timer.elapsed} sec.') print(f'Cache warmup finished cost {timer.elapsed} sec.')
def flush(self): def flush(self):
"""flush all CUDA chunks to CPU. """flush all CUDA rows to CPU.
The function is usually called after training finished. The function is usually called after training finished.
""" """
slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1) slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
chunk_ids = self.cached_idx_map[slots] row_ids = self.cached_idx_map[slots]
chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu() rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
self.weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks) self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
self.cached_idx_map.index_fill_(0, slots, -1) self.cached_idx_map.index_fill_(0, slots, -1)
self.inverted_cached_idx.index_fill_(0, chunk_ids, -1) self.inverted_cached_idx.index_fill_(0, row_ids, -1)
self._cuda_available_row_num += slots.numel() self._cuda_available_row_num += slots.numel()
assert self._cuda_available_row_num == self.cuda_row_num assert self._cuda_available_row_num == self.cuda_row_num
...@@ -280,25 +279,25 @@ class CachedParamMgr(torch.nn.Module): ...@@ -280,25 +279,25 @@ class CachedParamMgr(torch.nn.Module):
cpu_row_idxs = torch.unique(cpu_row_idxs_original) cpu_row_idxs = torch.unique(cpu_row_idxs_original)
assert len(cpu_row_idxs) <= self.cuda_row_num, \ assert len(cpu_row_idxs) <= self.cuda_row_num, \
f"the input indices pull {len(cpu_row_idxs)} chunks, " \ f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
f"which is larger than the presented {self.cuda_row_num}, " \ f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
f"please increase cuda_row_num shrink batch size" f"Please increase cuda_row_num or decrease the training batch size."
self.evict_backlist = cpu_row_idxs self.evict_backlist = cpu_row_idxs
with record_function("(zhg) get cpu chunk indices"): with record_function("(zhg) get cpu row idxs"):
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)] comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs)) self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
self.num_miss_history.append(len(comm_cpu_row_idxs)) self.num_miss_history.append(len(comm_cpu_row_idxs))
self.num_write_back_history.append(0) self.num_write_back_history.append(0)
# move sure the cuda chunk will not be evicted! # move sure the cuda rows will not be evicted!
with record_function("(zhg) cache update"): with record_function("(zhg) cache update"):
self._prepare_rows_on_cuda(comm_cpu_row_idxs) self._prepare_rows_on_cuda(comm_cpu_row_idxs)
self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype) self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
# new ids chunk_offset + offset_in_chunk
with record_function("(zhg) embed idx -> cache chunk id"): with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
gpu_row_idxs = self._id_to_cached_cuda_id(ids) gpu_row_idxs = self._id_to_cached_cuda_id(ids)
# update for LFU. # update for LFU.
...@@ -311,17 +310,17 @@ class CachedParamMgr(torch.nn.Module): ...@@ -311,17 +310,17 @@ class CachedParamMgr(torch.nn.Module):
self._cuda_to_cpu_elapse = 0 self._cuda_to_cpu_elapse = 0
self._cuda_to_cpu_numel = 0 self._cuda_to_cpu_numel = 0
def _chunk_in_cuda(self, chunk_id: int) -> bool: def _row_in_cuda(self, row_id: int) -> bool:
return self.inverted_cached_idx[chunk_id] != -1 return self.inverted_cached_idx[row_id] != -1
@torch.no_grad() @torch.no_grad()
def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
"""prepare rows in cpu_row_idxs on CUDA memory """prepare rows in cpu_row_idxs on CUDA memory
Args: Args:
cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
""" """
evict_num = cpu_row_idxs.numel() - self.cuda_available_chunk_num evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
if evict_num > 0: if evict_num > 0:
with Timer() as timer: with Timer() as timer:
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
...@@ -396,7 +395,7 @@ class CachedParamMgr(torch.nn.Module): ...@@ -396,7 +395,7 @@ class CachedParamMgr(torch.nn.Module):
""" """
deprecated deprecated
evict one chunk from cuda to cpu. evict one row from cuda to cpu.
Returns: Returns:
(int) : the slot id be evicted. (int) : the slot id be evicted.
""" """
......
...@@ -119,8 +119,4 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): ...@@ -119,8 +119,4 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
if self.cache_weight_mgr._cuda_to_cpu_numel > 0: if self.cache_weight_mgr._cuda_to_cpu_numel > 0:
return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
self.cache_weight_mgr._cuda_to_cpu_elapse self.cache_weight_mgr._cuda_to_cpu_elapse
return 0 return 0
\ No newline at end of file
@property
def input_id_percent_in_load_chunk(self):
return 0 # np.mean(self.cache_weight_mgr.input_id_percent_in_load_chunk) * 100
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment