Commit 9c8a2a14 authored by xinghao

Initial commit

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from torchrec.datasets.utils import Batch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
class RestartableMap:
def __init__(self, f, source):
self.source = source
self.func = f
def __iter__(self):
for x in self.source:
yield self.func(x)
def __len__(self):
return len(self.source)
class Multihot:
def __init__(
self,
multi_hot_sizes: list[int],
num_embeddings_per_feature: list[int],
batch_size: int,
collect_freqs_stats: bool,
dist_type: str = "uniform",
):
if dist_type not in {"uniform", "pareto"}:
raise ValueError(
"Multi-hot distribution type {} is not supported."
'Only "uniform" and "pareto" are supported.'.format(dist_type)
)
self.dist_type = dist_type
self.multi_hot_sizes = multi_hot_sizes
self.num_embeddings_per_feature = num_embeddings_per_feature
self.batch_size = batch_size
# Generate 1-hot to multi-hot lookup tables, one lookup table per sparse embedding table.
self.multi_hot_tables_l = self.__make_multi_hot_indices_tables(
dist_type, multi_hot_sizes, num_embeddings_per_feature
)
# Pooling offsets are computed once and reused.
self.offsets = self.__make_offsets(
multi_hot_sizes, num_embeddings_per_feature, batch_size
)
# For collecting embedding-access frequency statistics (saved for plotting)
self.collect_freqs_stats = collect_freqs_stats
self.model_to_track = None
self.freqs_pre_hash = []
self.freqs_post_hash = []
for embs_count in num_embeddings_per_feature:
self.freqs_pre_hash.append(np.zeros(embs_count))
self.freqs_post_hash.append(np.zeros(embs_count))
def save_freqs_stats(self) -> None:
if torch.distributed.is_available() and torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
else:
rank = 0
pre_dict = {str(k): e for k, e in enumerate(self.freqs_pre_hash)}
np.save(f"stats_pre_hash_{rank}_{self.dist_type}.npy", pre_dict)
post_dict = {str(k): e for k, e in enumerate(self.freqs_post_hash)}
np.save(f"stats_post_hash_{rank}_{self.dist_type}.npy", post_dict)
def pause_stats_collection_during_val_and_test(
self, model: torch.nn.Module
) -> None:
self.model_to_track = model
def __make_multi_hot_indices_tables(
self,
dist_type: str,
multi_hot_sizes: list[int],
num_embeddings_per_feature: list[int],
) -> list[np.ndarray]:
np.random.seed(0)  # The seed is necessary for all ranks to produce the same lookup values.
multi_hot_tables_l = []
for embs_count, multi_hot_size in zip(
num_embeddings_per_feature, multi_hot_sizes
):
embedding_ids = np.arange(embs_count)[:, np.newaxis]
if dist_type == "uniform":
synthetic_sparse_ids = np.random.randint(
0, embs_count, size=(embs_count, multi_hot_size - 1)
)
elif dist_type == "pareto":
synthetic_sparse_ids = (
np.random.pareto(
a=0.25, size=(embs_count, multi_hot_size - 1)
).astype(np.int32)
% embs_count
)
multi_hot_table = np.concatenate(
(embedding_ids, synthetic_sparse_ids), axis=-1
)
multi_hot_tables_l.append(multi_hot_table)
multi_hot_tables_l = [
torch.from_numpy(multi_hot_table).int()
for multi_hot_table in multi_hot_tables_l
]
return multi_hot_tables_l
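# Layout note (an added, hedged clarification; not in the original source): row i of each
# lookup table is [i, r_1, ..., r_{multi_hot_size - 1}], i.e. the original one-hot id in
# column 0 followed by multi_hot_size - 1 synthetic ids drawn from the chosen
# distribution. Looking up id i therefore always returns a multi-hot set containing i.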
def __make_offsets(
self,
multi_hot_sizes: list[int],
num_embeddings_per_feature: list[int],
batch_size: int,
) -> torch.Tensor:
lS_o = torch.ones(
(len(num_embeddings_per_feature) * batch_size), dtype=torch.int32
)
for k, multi_hot_size in enumerate(multi_hot_sizes):
lS_o[k * batch_size : (k + 1) * batch_size] = multi_hot_size
lS_o = torch.cumsum(torch.concat((torch.tensor([0]), lS_o)), axis=0)
return lS_o
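# Worked example (added for illustration; not in the original source): with
# multi_hot_sizes=[2, 3] and batch_size=2, lS_o starts as [2, 2, 3, 3]; prepending 0 and
# taking the cumulative sum gives offsets [0, 2, 4, 7, 10], so sample b of table k spans
# values[offsets[k * batch_size + b] : offsets[k * batch_size + b + 1]].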
def __make_new_batch(
self,
lS_i: torch.Tensor,
batch_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
lS_i = lS_i.reshape(-1, batch_size)
multi_hot_ids_l = []
for k, (sparse_data_batch_for_table, multi_hot_table) in enumerate(
zip(lS_i, self.multi_hot_tables_l)
):
multi_hot_ids = torch.nn.functional.embedding(
sparse_data_batch_for_table, multi_hot_table
)
multi_hot_ids = multi_hot_ids.reshape(-1)
multi_hot_ids_l.append(multi_hot_ids)
if self.collect_freqs_stats and (
self.model_to_track is None or self.model_to_track.training
):
idx_pre, cnt_pre = np.unique(
sparse_data_batch_for_table, return_counts=True
)
idx_post, cnt_post = np.unique(multi_hot_ids, return_counts=True)
self.freqs_pre_hash[k][idx_pre] += cnt_pre
self.freqs_post_hash[k][idx_post] += cnt_post
lS_i = torch.cat(multi_hot_ids_l)
if batch_size == self.batch_size:
return lS_i, self.offsets
else:
return lS_i, self.__make_offsets(
self.multi_hot_sizes, self.num_embeddings_per_feature, batch_size
)
def convert_to_multi_hot(self, batch: Batch) -> Batch:
batch_size = len(batch.dense_features)
lS_i = batch.sparse_features._values
lS_i, lS_o = self.__make_new_batch(lS_i, batch_size)
new_sparse_features = KeyedJaggedTensor.from_offsets_sync(
keys=batch.sparse_features._keys,
values=lS_i,
offsets=lS_o,
)
return Batch(
dense_features=batch.dense_features,
sparse_features=new_sparse_features,
labels=batch.labels,
)
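# A minimal usage sketch (added for illustration; not part of the original module). The
# toy sizes, feature names, and id values below are assumptions. It shows how Multihot and
# RestartableMap compose: every one-hot Batch produced by a dataloader is mapped through
# convert_to_multi_hot before being fed to the model.
if __name__ == "__main__":
    batch_size = 2
    keys = ["f0", "f1"]
    # Build a tiny one-hot batch by hand: one sparse id per (feature, sample),
    # laid out feature-major, i.e. f0's ids [3, 7] followed by f1's ids [1, 4].
    one_hot_batch = Batch(
        dense_features=torch.randn(batch_size, 13),
        sparse_features=KeyedJaggedTensor.from_lengths_sync(
            keys=keys,
            values=torch.tensor([3, 7, 1, 4]),
            lengths=torch.ones(len(keys) * batch_size, dtype=torch.int64),
        ),
        labels=torch.zeros(batch_size),
    )
    multihot = Multihot(
        multi_hot_sizes=[3, 2],
        num_embeddings_per_feature=[100, 50],
        batch_size=batch_size,
        collect_freqs_stats=False,
        dist_type="uniform",
    )
    multi_hot_batch = multihot.convert_to_multi_hot(one_hot_batch)
    # Each f0 id expands to 3 ids and each f1 id to 2, so 2*3 + 2*2 = 10 values.
    print(multi_hot_batch.sparse_features.values().shape)
    # Wrapping an iterable dataloader so the conversion happens on every pass:
    # train_loader = RestartableMap(multihot.convert_to_multi_hot, train_loader)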
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
base_url="https://storage.googleapis.com/criteo-cail-datasets/day_"
for i in {0..23}; do
url="$base_url$i.gz"
echo Downloading "$url"
wget "$url"
done
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import pathlib
import shutil
import sys
import numpy as np
import torch
from torch import distributed as dist, nn
from torchrec.datasets.criteo import DAYS
p = pathlib.Path(__file__).absolute().parents[1].resolve()
sys.path.append(os.fspath(p))
# OSS import
try:
# pyre-ignore[21]
# @manual=//ai_codesign/benchmarks/dlrm/torchrec_dlrm:multi_hot
from multi_hot import Multihot
except ImportError:
pass
# internal import
try:
from .multi_hot import Multihot # noqa F811
except ImportError:
pass
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Script to materialize synthetic multi-hot dataset into NumPy npz file format."
)
parser.add_argument(
"--in_memory_binary_criteo_path",
type=str,
required=True,
help="Path to a folder containing the binary (npy) files for the Criteo dataset."
" When supplied, InMemoryBinaryCriteoIterDataPipe is used.",
)
parser.add_argument(
"--output_path",
type=str,
required=True,
help="Path to outputted multi-hot sparse dataset.",
)
parser.add_argument(
"--copy_labels_and_dense",
dest="copy_labels_and_dense",
action="store_true",
help="Flag to determine whether to copy labels and dense data to the output directory.",
)
parser.add_argument(
"--num_embeddings_per_feature",
type=str,
required=True,
help="Comma separated max_ind_size per sparse feature. The number of embeddings"
" in each embedding table. 26 values are expected for the Criteo dataset.",
)
parser.add_argument(
"--multi_hot_sizes",
type=str,
required=True,
help="Comma separated multihot size per sparse feature. 26 values are expected for the Criteo dataset.",
)
parser.add_argument(
"--multi_hot_distribution_type",
type=str,
choices=["uniform", "pareto"],
default="uniform",
help="Multi-hot distribution options.",
)
return parser.parse_args()
def main() -> None:
"""
This script generates and saves the MLPerf v2 multi-hot dataset (4 TB in size).
First, run process_Criteo_1TB_Click_Logs_dataset.sh.
Then, run this script as follows:
python materialize_synthetic_multihot_dataset.py \
--in_memory_binary_criteo_path $PREPROCESSED_CRITEO_1TB_CLICK_LOGS_DATASET_PATH \
--output_path $MATERIALIZED_DATASET_PATH \
--num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 \
--multi_hot_sizes 3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1 \
--multi_hot_distribution_type uniform
This script takes about 2 hours to run (can be parallelized if needed).
"""
args = parse_args()
for name, val in vars(args).items():
try:
vars(args)[name] = list(map(int, val.split(",")))
except (ValueError, AttributeError):
pass
try:
backend = "nccl" if torch.cuda.is_available() else "gloo"
if not dist.is_initialized():
dist.init_process_group(backend=backend)
rank = dist.get_rank()
world_size = dist.get_world_size()
except (KeyError, ValueError):
rank = 0
world_size = 1
print("Generating one-hot to multi-hot lookup table.")
multihot = Multihot(
multi_hot_sizes=args.multi_hot_sizes,
num_embeddings_per_feature=args.num_embeddings_per_feature,
batch_size=1, # Doesn't matter
collect_freqs_stats=False,
dist_type=args.multi_hot_distribution_type,
)
os.makedirs(args.output_path, exist_ok=True)
for i in range(rank, DAYS, world_size):
input_file_path = os.path.join(
args.in_memory_binary_criteo_path, f"day_{i}_sparse.npy"
)
print(f"Materializing {input_file_path}")
sparse_data = np.load(input_file_path, mmap_mode="r")
multi_hot_ids_dict = {}
for j, (multi_hot_table, num_embs) in enumerate(
zip(multihot.multi_hot_tables_l, args.num_embeddings_per_feature)
):
sparse_tensor = torch.from_numpy(sparse_data[:, j] % num_embs)
multi_hot_ids_dict[str(j)] = nn.functional.embedding(
sparse_tensor, multi_hot_table
).numpy()
output_file_path = os.path.join(
args.output_path, f"day_{i}_sparse_multi_hot.npz"
)
np.savez(output_file_path, **multi_hot_ids_dict)
if args.copy_labels_and_dense:
for part in ["labels", "dense"]:
source_path = os.path.join(
args.in_memory_binary_criteo_path, f"day_{i}_{part}.npy"
)
output_path = os.path.join(args.output_path, f"day_{i}_{part}.npy")
print(f"Copying {source_path} to {output_path}")
shutil.copyfile(source_path, output_path)
if __name__ == "__main__":
main()
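# Hedged usage note (added; not part of the original script): because the day loop strides
# by world_size ("for i in range(rank, DAYS, world_size)"), materialization can be spread
# across processes with torchrun. An assumed example invocation:
#
#   torchrun --nproc_per_node=8 materialize_synthetic_multihot_dataset.py \
#       --in_memory_binary_criteo_path $PREPROCESSED_CRITEO_1TB_CLICK_LOGS_DATASET_PATH \
#       --output_path $MATERIALIZED_DATASET_PATH \
#       --num_embeddings_per_feature <values as in the docstring above> \
#       --multi_hot_sizes <values as in the docstring above> \
#       --multi_hot_distribution_type uniform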
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
display_help() {
echo "Three command line arguments are required."
echo "Example usage:"
echo "bash process_Criteo_1TB_Click_Logs_dataset.sh \\"
echo "./criteo_1tb/raw_input_dataset_dir \\"
echo "./criteo_1tb/temp_intermediate_files_dir \\"
echo "./criteo_1tb/numpy_contiguous_shuffled_output_dataset_dir"
exit 1
}
[ -z "$1" ] && display_help
[ -z "$2" ] && display_help
[ -z "$3" ] && display_help
# Input directory containing the raw Criteo 1TB Click Logs dataset files in tsv format.
# The 24 dataset filenames in the directory should be day_{0..23} with no .tsv extension.
raw_tsv_criteo_files_dir=$(readlink -m "$1")
# Directory to store temporary intermediate output files created by preprocessing steps 1 and 2.
temp_files_dir=$(readlink -m "$2")
# Directory to store temporary intermediate output files created by preprocessing step 1.
step_1_output_dir="$temp_files_dir/temp_output_of_step_1"
# Directory to store temporary intermediate output files created by preprocessing step 2.
step_2_output_dir="$temp_files_dir/temp_output_of_step_2"
# Directory to store the final preprocessed Criteo 1TB Click Logs dataset.
step_3_output_dir=$(readlink -m "$3")
# Step 1. Split the dataset into 3 sets of 24 numpy files:
# day_{0..23}_dense.npy, day_{0..23}_labels.npy, and day_{0..23}_sparse.npy (~24hrs)
set -x
mkdir -p "$step_1_output_dir"
date
python -m torchrec.datasets.scripts.npy_preproc_criteo --input_dir "$raw_tsv_criteo_files_dir" --output_dir "$step_1_output_dir" || exit
# Step 2. Convert all sparse indices in day_{0..23}_sparse.npy to contiguous indices and save the output.
# The output filenames are day_{0..23}_sparse_contig_freq.npy
mkdir -p "$step_2_output_dir"
date
python -m torchrec.datasets.scripts.contiguous_preproc_criteo --input_dir "$step_1_output_dir" --output_dir "$step_2_output_dir" --frequency_threshold 0 || exit
date
for i in {0..23}
do
name="$step_2_output_dir/day_$i""_sparse_contig_freq.npy"
renamed="$step_2_output_dir/day_$i""_sparse.npy"
echo "Renaming $name to $renamed"
mv "$name" "$renamed"
done
# Step 3. Shuffle the dataset's samples in days 0 through 22. (~20hrs)
# Day 23's samples are not shuffled and will be used for the validation set and test set.
mkdir -p "$step_3_output_dir"
date
python -m torchrec.datasets.scripts.shuffle_preproc_criteo --input_dir_labels_and_dense "$step_1_output_dir" --input_dir_sparse "$step_2_output_dir" --output_dir_shuffled "$step_3_output_dir" --random_seed 0 || exit
date
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import tempfile
import unittest
import uuid
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torchrec import test_utils
from torchrec.datasets.test_utils.criteo_test_utils import CriteoTest
from ..dlrm_main import main
class MainTest(unittest.TestCase):
@classmethod
def _run_trainer_random(cls) -> None:
main(
[
"--limit_train_batches",
"10",
"--limit_val_batches",
"8",
"--limit_test_batches",
"6",
"--over_arch_layer_sizes",
"8,1",
"--dense_arch_layer_sizes",
"8,8",
"--embedding_dim",
"8",
"--num_embeddings",
"8",
]
)
@test_utils.skip_if_asan
def test_main_function(self) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
lc = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=2,
run_id=str(uuid.uuid4()),
rdzv_backend="c10d",
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
rdzv_configs={"store_type": "file"},
start_method="spawn",
monitor_interval=1,
max_restarts=0,
)
elastic_launch(config=lc, entrypoint=self._run_trainer_random)()
@classmethod
def _run_trainer_criteo_in_memory(cls) -> None:
with CriteoTest._create_dataset_npys(
num_rows=50, filenames=[f"day_{i}" for i in range(24)]
) as files:
main(
[
"--over_arch_layer_sizes",
"8,1",
"--dense_arch_layer_sizes",
"8,8",
"--embedding_dim",
"8",
"--num_embeddings",
"64",
"--batch_size",
"2",
"--in_memory_binary_criteo_path",
os.path.dirname(files[0]),
"--epochs",
"2",
]
)
@test_utils.skip_if_asan
def test_main_function_criteo_in_memory(self) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
lc = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=2,
run_id=str(uuid.uuid4()),
rdzv_backend="c10d",
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
rdzv_configs={"store_type": "file"},
start_method="spawn",
monitor_interval=1,
max_restarts=0,
)
elastic_launch(config=lc, entrypoint=self._run_trainer_criteo_in_memory)()
@classmethod
def _run_trainer_dcn(cls) -> None:
with CriteoTest._create_dataset_npys(
num_rows=50, filenames=[f"day_{i}" for i in range(24)]
) as files:
main(
[
"--over_arch_layer_sizes",
"8,1",
"--dense_arch_layer_sizes",
"8,8",
"--embedding_dim",
"8",
"--num_embeddings",
"64",
"--batch_size",
"2",
"--in_memory_binary_criteo_path",
os.path.dirname(files[0]),
"--epochs",
"2",
"--interaction_type",
"dcn",
"--dcn_num_layers",
"2",
"--dcn_low_rank_dim",
"8",
]
)
@test_utils.skip_if_asan
def test_main_function_dcn(self) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
lc = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=2,
run_id=str(uuid.uuid4()),
rdzv_backend="c10d",
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
rdzv_configs={"store_type": "file"},
start_method="spawn",
monitor_interval=1,
max_restarts=0,
)
elastic_launch(config=lc, entrypoint=self._run_trainer_dcn)()
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Mixed-Dimensions Trick
#
# Description: Applies the mixed-dimension trick to embeddings to reduce
# embedding sizes.
#
# References:
# [1] Antonio Ginart, Maxim Naumov, Dheevatsa Mudigere, Jiyan Yang, James Zou,
# "Mixed Dimension Embeddings with Application to Memory-Efficient Recommendation
# Systems", CoRR, arXiv:1909.11810, 2019
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn as nn
def md_solver(n, alpha, d0=None, B=None, round_dim=True, k=None):
"""
An external facing function call for mixed-dimension assignment
with the alpha power temperature heuristic
Inputs:
n -- (torch.LongTensor) ; Vector of num of rows for each embedding matrix
alpha -- (torch.FloatTensor); Scalar, non-negative, controls dim. skew
d0 -- (torch.FloatTensor); Scalar, baseline embedding dimension
B -- (torch.FloatTensor); Scalar, parameter budget for embedding layer
round_dim -- (bool); flag for rounding dims to nearest pow of 2
k -- (torch.LongTensor) ; Vector of average number of queries per inference
"""
n, indices = torch.sort(n)
k = k[indices] if k is not None else torch.ones(len(n))
d = alpha_power_rule(n.type(torch.float) / k, alpha, d0=d0, B=B)
if round_dim:
d = pow_2_round(d)
undo_sort = [0] * len(indices)
for i, v in enumerate(indices):
undo_sort[v] = i
return d[undo_sort]
def alpha_power_rule(n, alpha, d0=None, B=None):
if d0 is not None:
lamb = d0 * (n[0].type(torch.float) ** alpha)
elif B is not None:
lamb = B / torch.sum(n.type(torch.float) ** (1 - alpha))
else:
raise ValueError("Must specify either d0 or B")
d = torch.ones(len(n)) * lamb * (n.type(torch.float) ** (-alpha))
for i in range(len(d)):
if i == 0 and d0 is not None:
d[i] = d0
else:
d[i] = 1 if d[i] < 1 else d[i]
return torch.round(d).type(torch.long)
def pow_2_round(dims):
return 2 ** torch.round(torch.log2(dims.type(torch.float)))
class PrEmbeddingBag(nn.Module):
def __init__(self, num_embeddings, embedding_dim, base_dim):
super(PrEmbeddingBag, self).__init__()
self.embs = nn.EmbeddingBag(
num_embeddings, embedding_dim, mode="sum", sparse=True
)
torch.nn.init.xavier_uniform_(self.embs.weight)
if embedding_dim < base_dim:
self.proj = nn.Linear(embedding_dim, base_dim, bias=False)
torch.nn.init.xavier_uniform_(self.proj.weight)
elif embedding_dim == base_dim:
self.proj = nn.Identity()
else:
raise ValueError(
"Embedding dim " + str(embedding_dim) + " > base dim " + str(base_dim)
)
def forward(self, input, offsets=None, per_sample_weights=None):
return self.proj(
self.embs(input, offsets=offsets, per_sample_weights=per_sample_weights)
)
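# A small illustrative sketch (added; not part of the original file) of the alpha-power
# heuristic described in md_solver's docstring: larger tables receive smaller embedding
# dimensions, and with round_dim=True each dimension is rounded to a power of two. The
# row counts and alpha below are assumed toy values.
if __name__ == "__main__":
    table_rows = torch.tensor([1_000_000, 50_000, 1_000], dtype=torch.long)
    dims = md_solver(table_rows, alpha=0.3, d0=64, round_dim=True)
    # One dimension per table, in the original order; the largest table gets the smallest dim.
    print(dims)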
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Quotient-Remainder Trick
#
# Description: Applies quotient remainder-trick to embeddings to reduce
# embedding sizes.
#
# References:
# [1] Hao-Jun Michael Shi, Dheevatsa Mudigere, Maxim Naumov, Jiyan Yang,
# "Compositional Embeddings Using Complementary Partitions for Memory-Efficient
# Recommendation Systems", CoRR, arXiv:1909.02107, 2019
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class QREmbeddingBag(nn.Module):
r"""Computes sums or means over two 'bags' of embeddings, one using the quotient
of the indices and the other using the remainder of the indices, without
instantiating the intermediate embeddings, then performs an operation to combine these.
For bags of constant length and no :attr:`per_sample_weights`, this class
* with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=0)``,
* with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=0)``,
* with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=0)``.
However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
operations.
QREmbeddingBag also supports per-sample weights as an argument to the forward
pass. This scales the output of the Embedding before performing a weighted
reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
:attr:`per_sample_weights`.
Known Issues:
Autograd breaks with multiple GPUs. It breaks only with multiple embeddings.
Args:
num_categories (int): total number of unique categories. The input indices must be in
0, 1, ..., num_categories - 1.
embedding_dim (list): list of sizes for each embedding vector in each table. If the ``"add"``
or ``"mult"`` operation is used, these embedding dimensions must be
the same. If a single embedding_dim is used, it will be used as the
embedding_dim for both embedding tables.
num_collisions (int): number of collisions to enforce.
operation (string, optional): ``"concat"``, ``"add"``, or ``"mult"``. Specifies the operation
to compose embeddings. ``"concat"`` concatenates the embeddings,
``"add"`` sums the embeddings, and ``"mult"`` multiplies
(component-wise) the embeddings.
Default: ``"mult"``
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of
the words in the mini-batch. Default ``False``.
Note: this option is not supported when ``mode="max"``.
mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
into consideration. ``"mean"`` computes the average of the values
in the bag, ``"max"`` computes the max value over each bag.
Default: ``"mean"``
sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
Notes for more details regarding sparse gradients. Note: this option is not
supported when ``mode="max"``.
Attributes:
weight (Tensor): the learnable weights of each embedding table of the module, of shape
`(num_embeddings, embedding_dim)`, initialized using a uniform distribution
with sqrt(1 / num_categories).
Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
:attr:`per_index_weights` (Tensor, optional)
- If :attr:`input` is 2D of shape `(B, N)`,
it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and
this will return ``B`` values aggregated in a way depending on the :attr:`mode`.
:attr:`offsets` is ignored and required to be ``None`` in this case.
- If :attr:`input` is 1D of shape `(N)`,
it will be treated as a concatenation of multiple bags (sequences).
:attr:`offsets` is required to be a 1D tensor containing the
starting index positions of each bag in :attr:`input`. Therefore,
for :attr:`offsets` of shape `(B)`, :attr:`input` will be viewed as
having ``B`` bags. Empty bags (i.e., having 0-length) will have
returned vectors filled by zeros.
per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
must have exactly the same shape as input and is treated as having the same
:attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.
Output shape: `(B, embedding_dim)`
"""
__constants__ = [
"num_categories",
"embedding_dim",
"num_collisions",
"operation",
"max_norm",
"norm_type",
"scale_grad_by_freq",
"mode",
"sparse",
]
def __init__(
self,
num_categories,
embedding_dim,
num_collisions,
operation="mult",
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
mode="mean",
sparse=False,
_weight=None,
):
super(QREmbeddingBag, self).__init__()
assert operation in ["concat", "mult", "add"], "Not valid operation!"
self.num_categories = num_categories
if isinstance(embedding_dim, int):
    self.embedding_dim = [embedding_dim, embedding_dim]
elif len(embedding_dim) == 1:
    self.embedding_dim = [embedding_dim[0], embedding_dim[0]]
else:
    self.embedding_dim = embedding_dim
self.num_collisions = num_collisions
self.operation = operation
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
if self.operation == "add" or self.operation == "mult":
assert (
self.embedding_dim[0] == self.embedding_dim[1]
), "Embedding dimensions do not match!"
self.num_embeddings = [
int(np.ceil(num_categories / num_collisions)),
num_collisions,
]
if _weight is None:
self.weight_q = Parameter(
torch.Tensor(self.num_embeddings[0], self.embedding_dim[0])
)
self.weight_r = Parameter(
torch.Tensor(self.num_embeddings[1], self.embedding_dim[1])
)
self.reset_parameters()
else:
assert (
list(_weight[0].shape)
== [
self.num_embeddings[0],
self.embedding_dim[0],
]
), "Shape of weight for quotient table does not match num_embeddings and embedding_dim"
assert (
list(_weight[1].shape)
== [
self.num_embeddings[1],
self.embedding_dim[1],
]
), "Shape of weight for remainder table does not match num_embeddings and embedding_dim"
self.weight_q = Parameter(_weight[0])
self.weight_r = Parameter(_weight[1])
self.mode = mode
self.sparse = sparse
def reset_parameters(self):
nn.init.uniform_(self.weight_q, np.sqrt(1 / self.num_categories))
nn.init.uniform_(self.weight_r, np.sqrt(1 / self.num_categories))
def forward(self, input, offsets=None, per_sample_weights=None):
input_q = (input / self.num_collisions).long()
input_r = torch.remainder(input, self.num_collisions).long()
embed_q = F.embedding_bag(
input_q,
self.weight_q,
offsets,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.mode,
self.sparse,
per_sample_weights,
)
embed_r = F.embedding_bag(
input_r,
self.weight_r,
offsets,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.mode,
self.sparse,
per_sample_weights,
)
if self.operation == "concat":
embed = torch.cat((embed_q, embed_r), dim=1)
elif self.operation == "add":
embed = embed_q + embed_r
elif self.operation == "mult":
embed = embed_q * embed_r
return embed
def extra_repr(self):
s = "{num_embeddings}, {embedding_dim}"
if self.max_norm is not None:
s += ", max_norm={max_norm}"
if self.norm_type != 2:
s += ", norm_type={norm_type}"
if self.scale_grad_by_freq is not False:
s += ", scale_grad_by_freq={scale_grad_by_freq}"
s += ", mode={mode}"
return s.format(**self.__dict__)
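# A minimal usage sketch (added; not part of the original file). The sizes and indices are
# assumed toy values. With num_categories=100 and num_collisions=10, index 57 is looked up
# as quotient 5 in the first table and remainder 7 in the second, and the two pooled
# embeddings are combined with the chosen operation ("mult" here).
if __name__ == "__main__":
    bag = QREmbeddingBag(
        num_categories=100,
        embedding_dim=16,
        num_collisions=10,
        operation="mult",
        mode="sum",
    )
    indices = torch.tensor([57, 3, 98], dtype=torch.long)
    offsets = torch.tensor([0, 1], dtype=torch.long)  # two bags: [57] and [3, 98]
    out = bag(indices, offsets)
    print(out.shape)  # torch.Size([2, 16])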