Unverified Commit cde7b8a5 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[FAW] init an LFU implementation for FAW (#1488)

parent 32efe8e7
...@@ -3,10 +3,10 @@ from .linear import ColoLinear ...@@ -3,10 +3,10 @@ from .linear import ColoLinear
from .embedding import ColoEmbedding from .embedding import ColoEmbedding
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy
__all__ = [ __all__ = [
'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', 'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr', 'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
'LimitBuffIndexCopyer' 'LimitBuffIndexCopyer', 'EvictionStrategy'
] ]
from .cache_mgr import CachedParamMgr from .cache_mgr import CachedParamMgr, EvictionStrategy
from .copyer import LimitBuffIndexCopyer from .copyer import LimitBuffIndexCopyer
from .freq_aware_embedding import FreqAwareEmbeddingBag from .freq_aware_embedding import FreqAwareEmbeddingBag
from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag'] __all__ = [
'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
'EvictionStrategy'
]
...@@ -4,6 +4,12 @@ from torch.profiler import record_function ...@@ -4,6 +4,12 @@ from torch.profiler import record_function
from typing import List, Optional from typing import List, Optional
from contexttimer import Timer from contexttimer import Timer
from .copyer import LimitBuffIndexCopyer from .copyer import LimitBuffIndexCopyer
from enum import Enum
class EvictionStrategy(Enum):
LFU = 1
DATASET = 2
class CachedParamMgr(torch.nn.Module): class CachedParamMgr(torch.nn.Module):
...@@ -18,7 +24,8 @@ class CachedParamMgr(torch.nn.Module): ...@@ -18,7 +24,8 @@ class CachedParamMgr(torch.nn.Module):
weight: torch.Tensor, weight: torch.Tensor,
cuda_row_num: int = 0, cuda_row_num: int = 0,
buffer_size: int = 50_000, buffer_size: int = 50_000,
pin_weight=False) -> None: pin_weight=False,
evict_strategy=EvictionStrategy.DATASET) -> None:
super(CachedParamMgr, self).__init__() super(CachedParamMgr, self).__init__()
self.buffer_size = buffer_size self.buffer_size = buffer_size
self.num_embeddings, self.embedding_dim = weight.shape self.num_embeddings, self.embedding_dim = weight.shape
...@@ -38,6 +45,51 @@ class CachedParamMgr(torch.nn.Module): ...@@ -38,6 +45,51 @@ class CachedParamMgr(torch.nn.Module):
self.input_id_percent_in_load_chunk = [] self.input_id_percent_in_load_chunk = []
self._reset_comm_stats() self._reset_comm_stats()
self._evict_strategy = evict_strategy
if self._evict_strategy == EvictionStrategy.LFU:
# cpu_row_idx -> frequency, freq of the cpu rows.
# evict the minimal freq value row in cuda cache.
self.register_buffer("freq_cnter",
torch.empty(self.num_embeddings, device=torch.cuda.current_device(),
dtype=torch.long).fill_(0),
persistent=False)
def _update_freq_cnter(self, cpu_row_idxs: torch.Tensor) -> None:
"""_update_freq_cnter
Update the frequency valude w.r.t. the cpu_row_ids in self.freq_cnter.
Args:
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
"""
if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter[cpu_row_idxs] += 1
def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
"""_find_evict_gpu_idxs
Find the gpu idxs to be evicted, according to their freq.
Args:
evict_num (int): how many rows has to be evicted
Returns:
torch.Tensor: a list tensor (1D), contains the gpu_row_idxs.
"""
if self._evict_strategy == EvictionStrategy.LFU:
# find the minimal evict_num freq entries in cached_idx_map
evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
return self.cached_idx_map[evict_gpu_row_idxs]
elif self._evict_strategy == EvictionStrategy.DATASET:
# cached_idx_map itself implies the priority of eviction.
# The value of self.cached_idx_map represents cpu_row_idx.
# The larger it is, the less frequently it will appear in the dataset,
# and the higher its eviction priority will be.
return torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
else:
raise TypeError
def _init_weight(self, weight): def _init_weight(self, weight):
if self.cuda_row_num > 0: if self.cuda_row_num > 0:
# Enable cache with introducing auxiliary data structures # Enable cache with introducing auxiliary data structures
...@@ -220,6 +272,10 @@ class CachedParamMgr(torch.nn.Module): ...@@ -220,6 +272,10 @@ class CachedParamMgr(torch.nn.Module):
# new ids chunk_offset + offset_in_chunk # new ids chunk_offset + offset_in_chunk
with record_function("(zhg) embed idx -> cache chunk id"): with record_function("(zhg) embed idx -> cache chunk id"):
gpu_row_idxs = self._id_to_cached_cuda_id(ids) gpu_row_idxs = self._id_to_cached_cuda_id(ids)
# update for LFU.
self._update_freq_cnter(cpu_row_idxs)
return gpu_row_idxs return gpu_row_idxs
def _reset_comm_stats(self): def _reset_comm_stats(self):
...@@ -234,6 +290,7 @@ class CachedParamMgr(torch.nn.Module): ...@@ -234,6 +290,7 @@ class CachedParamMgr(torch.nn.Module):
@torch.no_grad() @torch.no_grad()
def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
"""prepare rows in cpu_row_idxs on CUDA memory """prepare rows in cpu_row_idxs on CUDA memory
Args: Args:
cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA
""" """
...@@ -245,7 +302,9 @@ class CachedParamMgr(torch.nn.Module): ...@@ -245,7 +302,9 @@ class CachedParamMgr(torch.nn.Module):
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
self.cached_idx_map.index_fill_(0, invalid_idxs, -2) self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
evict_gpu_row_idxs = torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
evict_info = self.cached_idx_map[evict_gpu_row_idxs] evict_info = self.cached_idx_map[evict_gpu_row_idxs]
...@@ -291,8 +350,16 @@ class CachedParamMgr(torch.nn.Module): ...@@ -291,8 +350,16 @@ class CachedParamMgr(torch.nn.Module):
self._cpu_to_cuda_numel += weight_size self._cpu_to_cuda_numel += weight_size
# print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB") # print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
def _find_free_cuda_row(self) -> int:
if self._cuda_available_row_num == 0:
return -1
candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
return candidates[0].item()
def _evict(self) -> int: def _evict(self) -> int:
""" """
deprecated
evict one chunk from cuda to cpu. evict one chunk from cuda to cpu.
Returns: Returns:
(int) : the slot id be evicted. (int) : the slot id be evicted.
...@@ -329,15 +396,11 @@ class CachedParamMgr(torch.nn.Module): ...@@ -329,15 +396,11 @@ class CachedParamMgr(torch.nn.Module):
# self.num_write_back_history[-1] += 1 # self.num_write_back_history[-1] += 1
return max_cpu_row_idx return max_cpu_row_idx
def _find_free_cuda_row(self) -> int:
if self._cuda_available_row_num == 0:
return -1
candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
return candidates[0].item()
@torch.no_grad() @torch.no_grad()
def _admit(self, row_id: int): def _admit(self, row_id: int):
""" """
deprecated
move in row_id to CUDA move in row_id to CUDA
Args: Args:
......
...@@ -3,35 +3,35 @@ import torch.nn.functional as F ...@@ -3,35 +3,35 @@ import torch.nn.functional as F
from typing import List, Optional, Iterator, Tuple from typing import List, Optional, Iterator, Tuple
from .base_embedding import BaseEmbeddingBag from .base_embedding import BaseEmbeddingBag
from .cache_mgr import CachedParamMgr from .cache_mgr import CachedParamMgr, EvictionStrategy
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
class FreqAwareEmbeddingBag(BaseEmbeddingBag): class FreqAwareEmbeddingBag(BaseEmbeddingBag):
def __init__( def __init__(self,
self, num_embeddings,
num_embeddings, embedding_dim,
embedding_dim, padding_idx=None,
padding_idx=None, max_norm=None,
max_norm=None, norm_type=2.,
norm_type=2., scale_grad_by_freq=False,
scale_grad_by_freq=False, sparse=False,
sparse=False, _weight=None,
_weight=None, mode='mean',
mode='mean', include_last_offset=False,
include_last_offset=False, dtype=None,
dtype=None, device=None,
device=None, cuda_row_num=0,
cuda_row_num=0, ids_freq_mapping=None,
ids_freq_mapping=None, warmup_ratio=0.7,
warmup_ratio=0.7, buffer_size=50_000,
buffer_size=50_000, pin_weight=False,
pin_weight=False, evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
):
super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
scale_grad_by_freq, sparse, mode, include_last_offset) scale_grad_by_freq, sparse, mode, include_last_offset)
self.evict_strategy = evict_strategy
if _weight is None: if _weight is None:
_weight = self._weight_alloc(dtype, device) _weight = self._weight_alloc(dtype, device)
...@@ -63,7 +63,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): ...@@ -63,7 +63,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
ids_freq_mapping (List[int]): a list, idx is id number, value is freq ids_freq_mapping (List[int]): a list, idx is id number, value is freq
warmup_ratio (float): the amount of rows preloaded in cuda cache warmup_ratio (float): the amount of rows preloaded in cuda cache
""" """
self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size, pin_weight) self.cache_weight_mgr = CachedParamMgr(weight,
cuda_row_num,
buffer_size,
pin_weight,
evict_strategy=self.evict_strategy)
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None): def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None):
......
...@@ -12,7 +12,7 @@ from colossalai.utils import free_port ...@@ -12,7 +12,7 @@ from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \ from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
ColoTensor, ColoTensorSpec ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy
NUM_EMBED, EMBED_DIM = 10, 8 NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE = 8 BATCH_SIZE = 8
...@@ -41,6 +41,7 @@ def synthesize_1d_sparse_feature( ...@@ -41,6 +41,7 @@ def synthesize_1d_sparse_feature(
return indices, offsets return indices, offsets
@pytest.mark.skip
def test_cachemgr(): def test_cachemgr():
model = torch.nn.EmbeddingBag(10000, 128) model = torch.nn.EmbeddingBag(10000, 128)
# 10 chunks, 5 in cuda # 10 chunks, 5 in cuda
...@@ -98,14 +99,17 @@ def test_reorder_with_freq(): ...@@ -98,14 +99,17 @@ def test_reorder_with_freq():
f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
def test_freq_aware_embed(): @pytest.mark.parametrize('use_LFU', [True, False])
def test_freq_aware_embed(use_LFU: bool):
device = torch.device('cuda', 0) device = torch.device('cuda', 0)
evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
model = FreqAwareEmbeddingBag(NUM_EMBED, model = FreqAwareEmbeddingBag(NUM_EMBED,
EMBED_DIM, EMBED_DIM,
mode='mean', mode='mean',
include_last_offset=True, include_last_offset=True,
cuda_row_num=BATCH_SIZE * 2, cuda_row_num=BATCH_SIZE * 2,
ids_freq_mapping=None).to(device) ids_freq_mapping=None,
evict_strategy=evict_strategy).to(device)
assert model.weight.shape[0] == NUM_EMBED assert model.weight.shape[0] == NUM_EMBED
ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device), ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
...@@ -231,6 +235,5 @@ def test_parallel_freq_aware_embed(world_size): ...@@ -231,6 +235,5 @@ def test_parallel_freq_aware_embed(world_size):
if __name__ == '__main__': if __name__ == '__main__':
test_cachemgr() test_freq_aware_embed(True)
# test_freq_aware_embed()
# test_parallel_freq_aware_embed(2) # test_parallel_freq_aware_embed(2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment