Unverified Commit e57df803 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[embeddings] cache option (#1635)

parent a088022e
...@@ -97,12 +97,13 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): ...@@ -97,12 +97,13 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
evict_strategy=self.evict_strategy) evict_strategy=self.evict_strategy)
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None): def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
with torch.no_grad(): if cache_op:
reorder_ids = self.cache_weight_mgr.prepare_ids(input) with torch.no_grad():
input = self.cache_weight_mgr.prepare_ids(input)
embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset, self.padding_idx) per_sample_weights, self.include_last_offset, self.padding_idx)
if shape_hook is not None: if shape_hook is not None:
embeddings = shape_hook(embeddings) embeddings = shape_hook(embeddings)
......
...@@ -72,11 +72,19 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag): ...@@ -72,11 +72,19 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
compute_attr=ComputePattern.TP1D) compute_attr=ComputePattern.TP1D)
return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec) return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1): def forward(self,
with torch.no_grad(): indices,
reorder_ids = self.cache_weight_mgr.prepare_ids(indices) offsets=None,
output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, per_sample_weights=None,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, shape_hook=None,
scatter_dim=0,
gather_dim=-1,
cache_op: bool = True):
if cache_op:
with torch.no_grad():
indices = self.cache_weight_mgr.prepare_ids(indices)
output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset, self.padding_idx) per_sample_weights, self.include_last_offset, self.padding_idx)
if shape_hook is not None: if shape_hook is not None:
output_shard = shape_hook(output_shard) output_shard = shape_hook(output_shard)
......
...@@ -86,7 +86,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): ...@@ -86,7 +86,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
offsets: torch.Tensor = None, offsets: torch.Tensor = None,
per_sample_weights=None, per_sample_weights=None,
shape_hook=None, shape_hook=None,
already_split_along_rank=True): already_split_along_rank=True,
cache_op=True):
if not already_split_along_rank: if not already_split_along_rank:
# not recommended. it takes time. # not recommended. it takes time.
batch_size = (offsets.shape[0]) // self.global_tables_num batch_size = (offsets.shape[0]) // self.global_tables_num
...@@ -96,9 +97,10 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): ...@@ -96,9 +97,10 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
# recommended. # recommended.
batch_size = (offsets.shape[0]) // len(self.assigned_table_list) batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
with torch.no_grad(): if cache_op:
reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices) with torch.no_grad():
local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets, indices = self.cache_weight_mgr.prepare_ids(local_indices)
local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
local_per_sample_weights, self.include_last_offset, self.padding_idx) local_per_sample_weights, self.include_last_offset, self.padding_idx)
local_output = torch.cat(local_output.split(batch_size), 1) local_output = torch.cat(local_output.split(batch_size), 1)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment