Ecological Empowerment / dlrm · Commits

Commit 9c8a2a14, authored Oct 21, 2025 by xinghao

Initial commit

Pipeline #3002 canceled with stages. Changes: 48. Pipelines: 1.

Showing 8 changed files with 874 additions and 0 deletions.
torchrec_dlrm/multi_hot.py (+174, -0)
torchrec_dlrm/requirements.txt (+2, -0)
torchrec_dlrm/scripts/download_Criteo_1TB_Click_Logs_dataset.sh (+12, -0)
torchrec_dlrm/scripts/materialize_synthetic_multihot_dataset.py (+153, -0)
torchrec_dlrm/scripts/process_Criteo_1TB_Click_Logs_dataset.sh (+64, -0)
torchrec_dlrm/tests/test_dlrm_main.py (+147, -0)
tricks/md_embedding_bag.py (+85, -0)
tricks/qr_embedding_bag.py (+237, -0)
torchrec_dlrm/multi_hot.py (new file, 0 → 100644)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torchrec.datasets.utils import Batch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class RestartableMap:
    def __init__(self, f, source):
        self.source = source
        self.func = f

    def __iter__(self):
        for x in self.source:
            yield self.func(x)

    def __len__(self):
        return len(self.source)


class Multihot:
    def __init__(
        self,
        multi_hot_sizes: list[int],
        num_embeddings_per_feature: list[int],
        batch_size: int,
        collect_freqs_stats: bool,
        dist_type: str = "uniform",
    ):
        if dist_type not in {"uniform", "pareto"}:
            raise ValueError(
                "Multi-hot distribution type {} is not supported."
                'Only "uniform" and "pareto" are supported.'.format(dist_type)
            )
        self.dist_type = dist_type
        self.multi_hot_sizes = multi_hot_sizes
        self.num_embeddings_per_feature = num_embeddings_per_feature
        self.batch_size = batch_size

        # Generate 1-hot to multi-hot lookup tables, one lookup table per sparse embedding table.
        self.multi_hot_tables_l = self.__make_multi_hot_indices_tables(
            dist_type, multi_hot_sizes, num_embeddings_per_feature
        )

        # Pooling offsets are computed once and reused.
        self.offsets = self.__make_offsets(
            multi_hot_sizes, num_embeddings_per_feature, batch_size
        )

        # For plotting frequency access
        self.collect_freqs_stats = collect_freqs_stats
        self.model_to_track = None
        self.freqs_pre_hash = []
        self.freqs_post_hash = []
        for embs_count in num_embeddings_per_feature:
            self.freqs_pre_hash.append(np.zeros(embs_count))
            self.freqs_post_hash.append(np.zeros(embs_count))

    def save_freqs_stats(self) -> None:
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
        else:
            rank = 0
        pre_dict = {str(k): e for k, e in enumerate(self.freqs_pre_hash)}
        np.save(f"stats_pre_hash_{rank}_{self.dist_type}.npy", pre_dict)
        post_dict = {str(k): e for k, e in enumerate(self.freqs_post_hash)}
        np.save(f"stats_post_hash_{rank}_{self.dist_type}.npy", post_dict)

    def pause_stats_collection_during_val_and_test(
        self, model: torch.nn.Module
    ) -> None:
        self.model_to_track = model

    def __make_multi_hot_indices_tables(
        self,
        dist_type: str,
        multi_hot_sizes: list[int],
        num_embeddings_per_feature: list[int],
    ) -> list[np.array]:
        # The seed is necessary for all ranks to produce the same lookup values.
        np.random.seed(0)
        multi_hot_tables_l = []
        for embs_count, multi_hot_size in zip(
            num_embeddings_per_feature, multi_hot_sizes
        ):
            embedding_ids = np.arange(embs_count)[:, np.newaxis]
            if dist_type == "uniform":
                synthetic_sparse_ids = np.random.randint(
                    0, embs_count, size=(embs_count, multi_hot_size - 1)
                )
            elif dist_type == "pareto":
                synthetic_sparse_ids = (
                    np.random.pareto(
                        a=0.25, size=(embs_count, multi_hot_size - 1)
                    ).astype(np.int32)
                    % embs_count
                )
            multi_hot_table = np.concatenate(
                (embedding_ids, synthetic_sparse_ids), axis=-1
            )
            multi_hot_tables_l.append(multi_hot_table)
        multi_hot_tables_l = [
            torch.from_numpy(multi_hot_table).int()
            for multi_hot_table in multi_hot_tables_l
        ]
        return multi_hot_tables_l

    def __make_offsets(
        self,
        multi_hot_sizes: int,
        num_embeddings_per_feature: list[int],
        batch_size: int,
    ) -> list[torch.Tensor]:
        lS_o = torch.ones(
            (len(num_embeddings_per_feature) * batch_size), dtype=torch.int32
        )
        for k, multi_hot_size in enumerate(multi_hot_sizes):
            lS_o[k * batch_size : (k + 1) * batch_size] = multi_hot_size
        lS_o = torch.cumsum(torch.concat((torch.tensor([0]), lS_o)), axis=0)
        return lS_o

    def __make_new_batch(
        self,
        lS_i: torch.Tensor,
        batch_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        lS_i = lS_i.reshape(-1, batch_size)
        multi_hot_ids_l = []
        for k, (sparse_data_batch_for_table, multi_hot_table) in enumerate(
            zip(lS_i, self.multi_hot_tables_l)
        ):
            multi_hot_ids = torch.nn.functional.embedding(
                sparse_data_batch_for_table, multi_hot_table
            )
            multi_hot_ids = multi_hot_ids.reshape(-1)
            multi_hot_ids_l.append(multi_hot_ids)
            if self.collect_freqs_stats and (
                self.model_to_track is None or self.model_to_track.training
            ):
                idx_pre, cnt_pre = np.unique(
                    sparse_data_batch_for_table, return_counts=True
                )
                idx_post, cnt_post = np.unique(multi_hot_ids, return_counts=True)
                self.freqs_pre_hash[k][idx_pre] += cnt_pre
                self.freqs_post_hash[k][idx_post] += cnt_post
        lS_i = torch.cat(multi_hot_ids_l)
        if batch_size == self.batch_size:
            return lS_i, self.offsets
        else:
            return lS_i, self.__make_offsets(
                self.multi_hot_sizes, self.num_embeddings_per_feature, batch_size
            )

    def convert_to_multi_hot(self, batch: Batch) -> Batch:
        batch_size = len(batch.dense_features)
        lS_i = batch.sparse_features._values
        lS_i, lS_o = self.__make_new_batch(lS_i, batch_size)
        new_sparse_features = KeyedJaggedTensor.from_offsets_sync(
            keys=batch.sparse_features._keys,
            values=lS_i,
            offsets=lS_o,
        )
        return Batch(
            dense_features=batch.dense_features,
            sparse_features=new_sparse_features,
            labels=batch.labels,
        )
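The Multihot class above converts each one-hot batch into a synthetic multi-hot batch by looking every original index up in a per-table lookup matrix, and RestartableMap lets that conversion be applied lazily over a dataloader. Below is a minimal usage sketch, not part of this commit: the table sizes, pooling factors, feature keys, and the toy Batch are illustrative, and the import assumes the torchrec_dlrm directory is on the Python path.

# Illustrative sketch of using Multihot / RestartableMap (not part of the commit).
import torch
from torchrec.datasets.utils import Batch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

from multi_hot import Multihot, RestartableMap  # assumes torchrec_dlrm/ is on sys.path

num_embeddings_per_feature = [100, 200]  # illustrative table sizes
multi_hot_sizes = [3, 5]                 # illustrative pooling factor per feature
batch_size = 4

multihot = Multihot(
    multi_hot_sizes=multi_hot_sizes,
    num_embeddings_per_feature=num_embeddings_per_feature,
    batch_size=batch_size,
    collect_freqs_stats=False,
)

# A toy one-hot batch: one sparse index per feature per sample.
one_hot_batch = Batch(
    dense_features=torch.rand(batch_size, 13),
    sparse_features=KeyedJaggedTensor.from_lengths_sync(
        keys=["f0", "f1"],
        values=torch.randint(0, 100, (2 * batch_size,)),
        lengths=torch.ones(2 * batch_size, dtype=torch.int32),
    ),
    labels=torch.randint(0, 2, (batch_size,)),
)

multi_hot_batch = multihot.convert_to_multi_hot(one_hot_batch)

# Or convert lazily while iterating an existing dataloader:
# dataloader = RestartableMap(multihot.convert_to_multi_hot, dataloader)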
torchrec_dlrm/requirements.txt (new file, 0 → 100644)
tqdm
torchmetrics
torchrec_dlrm/scripts/download_Criteo_1TB_Click_Logs_dataset.sh (new file, 0 → 100644)
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

base_url="https://storage.googleapis.com/criteo-cail-datasets/day_"

for i in {0..23}; do
    url="$base_url$i.gz"
    echo Downloading "$url"
    wget "$url"
done
torchrec_dlrm/scripts/materialize_synthetic_multihot_dataset.py (new file, 0 → 100644)
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import pathlib
import shutil
import sys

import numpy as np
import torch
from torch import distributed as dist, nn
from torchrec.datasets.criteo import DAYS

p = pathlib.Path(__file__).absolute().parents[1].resolve()
sys.path.append(os.fspath(p))

# OSS import
try:
    # pyre-ignore[21]
    # @manual=//ai_codesign/benchmarks/dlrm/torchrec_dlrm:multi_hot
    from multi_hot import Multihot
except ImportError:
    pass

# internal import
try:
    from .multi_hot import Multihot  # noqa F811
except ImportError:
    pass


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Script to materialize synthetic multi-hot dataset into NumPy npz file format."
    )
    parser.add_argument(
        "--in_memory_binary_criteo_path",
        type=str,
        required=True,
        help="Path to a folder containing the binary (npy) files for the Criteo dataset."
        " When supplied, InMemoryBinaryCriteoIterDataPipe is used.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to outputted multi-hot sparse dataset.",
    )
    parser.add_argument(
        "--copy_labels_and_dense",
        dest="copy_labels_and_dense",
        action="store_true",
        help="Flag to determine whether to copy labels and dense data to the output directory.",
    )
    parser.add_argument(
        "--num_embeddings_per_feature",
        type=str,
        required=True,
        help="Comma separated max_ind_size per sparse feature. The number of embeddings"
        " in each embedding table. 26 values are expected for the Criteo dataset.",
    )
    parser.add_argument(
        "--multi_hot_sizes",
        type=str,
        required=True,
        help="Comma separated multihot size per sparse feature. 26 values are expected for the Criteo dataset.",
    )
    parser.add_argument(
        "--multi_hot_distribution_type",
        type=str,
        choices=["uniform", "pareto"],
        default="uniform",
        help="Multi-hot distribution options.",
    )
    return parser.parse_args()


def main() -> None:
    """
    This script generates and saves the MLPerf v2 multi-hot dataset (4 TB in size).
    First, run process_Criteo_1TB_Click_Logs_dataset.sh.
    Then, run this script as follows:

    python materialize_synthetic_multihot_dataset.py \
        --in_memory_binary_criteo_path $PREPROCESSED_CRITEO_1TB_CLICK_LOGS_DATASET_PATH \
        --output_path $MATERIALIZED_DATASET_PATH \
        --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 \
        --multi_hot_sizes 3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1 \
        --multi_hot_distribution_type uniform

    This script takes about 2 hours to run (can be parallelized if needed).
    """
    args = parse_args()
    for name, val in vars(args).items():
        try:
            vars(args)[name] = list(map(int, val.split(",")))
        except (ValueError, AttributeError):
            pass

    try:
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        if not dist.is_initialized():
            dist.init_process_group(backend=backend)
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    except (KeyError, ValueError):
        rank = 0
        world_size = 1

    print("Generating one-hot to multi-hot lookup table.")
    multihot = Multihot(
        multi_hot_sizes=args.multi_hot_sizes,
        num_embeddings_per_feature=args.num_embeddings_per_feature,
        batch_size=1,  # Doesn't matter
        collect_freqs_stats=False,
        dist_type=args.multi_hot_distribution_type,
    )

    os.makedirs(args.output_path, exist_ok=True)
    for i in range(rank, DAYS, world_size):
        input_file_path = os.path.join(
            args.in_memory_binary_criteo_path, f"day_{i}_sparse.npy"
        )
        print(f"Materializing {input_file_path}")
        sparse_data = np.load(input_file_path, mmap_mode="r")
        multi_hot_ids_dict = {}
        for j, (multi_hot_table, hash) in enumerate(
            zip(multihot.multi_hot_tables_l, args.num_embeddings_per_feature)
        ):
            sparse_tensor = torch.from_numpy(sparse_data[:, j] % hash)
            multi_hot_ids_dict[str(j)] = nn.functional.embedding(
                sparse_tensor, multi_hot_table
            ).numpy()
        output_file_path = os.path.join(
            args.output_path, f"day_{i}_sparse_multi_hot.npz"
        )
        np.savez(output_file_path, **multi_hot_ids_dict)
        if args.copy_labels_and_dense:
            for part in ["labels", "dense"]:
                source_path = os.path.join(
                    args.in_memory_binary_criteo_path, f"day_{i}_{part}.npy"
                )
                output_path = os.path.join(args.output_path, f"day_{i}_{part}.npy")
                shutil.copyfile(source_path, output_path)
                print(f"Copying {source_path} to {output_path}")


if __name__ == "__main__":
    main()
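For reference, each day_{i}_sparse_multi_hot.npz written by this script stores one array per sparse feature, keyed by the feature index as a string ("0" through "25" for Criteo), with shape (number of samples in the day, multi_hot_sizes[j]). A minimal reading sketch, not part of this commit; the path below is illustrative.

# Illustrative sketch of loading one materialized day back (not part of the commit).
import numpy as np

day = 0
# Point this at the --output_path used above.
with np.load(f"/path/to/materialized/day_{day}_sparse_multi_hot.npz") as sparse_npz:
    feature_0 = sparse_npz["0"]  # multi-hot ids for the first sparse feature
    print(feature_0.shape)       # (num_samples_in_day, multi_hot_sizes[0])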
torchrec_dlrm/scripts/process_Criteo_1TB_Click_Logs_dataset.sh (new file, 0 → 100644)
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

display_help()
{
    echo "Three command line arguments are required."
    echo "Example usage:"
    echo "bash process_Criteo_1TB_Click_Logs_dataset.sh \\"
    echo "./criteo_1tb/raw_input_dataset_dir \\"
    echo "./criteo_1tb/temp_intermediate_files_dir \\"
    echo "./criteo_1tb/numpy_contiguous_shuffled_output_dataset_dir"
    exit 1
}

[ -z "$1" ] && display_help
[ -z "$2" ] && display_help
[ -z "$3" ] && display_help

# Input directory containing the raw Criteo 1TB Click Logs dataset files in tsv format.
# The 24 dataset filenames in the directory should be day_{0..23} with no .tsv extension.
raw_tsv_criteo_files_dir=$(readlink -m "$1")

# Directory to store temporary intermediate output files created by preprocessing steps 1 and 2.
temp_files_dir=$(readlink -m "$2")

# Directory to store temporary intermediate output files created by preprocessing step 1.
step_1_output_dir="$temp_files_dir/temp_output_of_step_1"

# Directory to store temporary intermediate output files created by preprocessing step 2.
step_2_output_dir="$temp_files_dir/temp_output_of_step_2"

# Directory to store the final preprocessed Criteo 1TB Click Logs dataset.
step_3_output_dir=$(readlink -m "$3")

# Step 1. Split the dataset into 3 sets of 24 numpy files:
# day_{0..23}_dense.npy, day_{0..23}_labels.npy, and day_{0..23}_sparse.npy (~24hrs)
set -x
mkdir -p "$step_1_output_dir"
date
python -m torchrec.datasets.scripts.npy_preproc_criteo --input_dir "$raw_tsv_criteo_files_dir" --output_dir "$step_1_output_dir" || exit

# Step 2. Convert all sparse indices in day_{0..23}_sparse.npy to contiguous indices and save the output.
# The output filenames are day_{0..23}_sparse_contig_freq.npy
mkdir -p "$step_2_output_dir"
date
python -m torchrec.datasets.scripts.contiguous_preproc_criteo --input_dir "$step_1_output_dir" --output_dir "$step_2_output_dir" --frequency_threshold 0 || exit
date
for i in {0..23}
do
    name="$step_2_output_dir/day_$i""_sparse_contig_freq.npy"
    renamed="$step_2_output_dir/day_$i""_sparse.npy"
    echo "Renaming $name to $renamed"
    mv "$name" "$renamed"
done

# Step 3. Shuffle the dataset's samples in days 0 through 22. (~20hrs)
# Day 23's samples are not shuffled and will be used for the validation set and test set.
mkdir -p "$step_3_output_dir"
date
python -m torchrec.datasets.scripts.shuffle_preproc_criteo --input_dir_labels_and_dense "$step_1_output_dir" --input_dir_sparse "$step_2_output_dir" --output_dir_shuffled "$step_3_output_dir" --random_seed 0 || exit
date
torchrec_dlrm/tests/test_dlrm_main.py (new file, 0 → 100644)
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import unittest
import uuid

from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torchrec import test_utils
from torchrec.datasets.test_utils.criteo_test_utils import CriteoTest

from ..dlrm_main import main


class MainTest(unittest.TestCase):
    @classmethod
    def _run_trainer_random(cls) -> None:
        main(
            [
                "--limit_train_batches", "10",
                "--limit_val_batches", "8",
                "--limit_test_batches", "6",
                "--over_arch_layer_sizes", "8,1",
                "--dense_arch_layer_sizes", "8,8",
                "--embedding_dim", "8",
                "--num_embeddings", "8",
            ]
        )

    @test_utils.skip_if_asan
    def test_main_function(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdir:
            lc = LaunchConfig(
                min_nodes=1,
                max_nodes=1,
                nproc_per_node=2,
                run_id=str(uuid.uuid4()),
                rdzv_backend="c10d",
                rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
                rdzv_configs={"store_type": "file"},
                start_method="spawn",
                monitor_interval=1,
                max_restarts=0,
            )
            elastic_launch(config=lc, entrypoint=self._run_trainer_random)()

    @classmethod
    def _run_trainer_criteo_in_memory(cls) -> None:
        with CriteoTest._create_dataset_npys(
            num_rows=50, filenames=[f"day_{i}" for i in range(24)]
        ) as files:
            main(
                [
                    "--over_arch_layer_sizes", "8,1",
                    "--dense_arch_layer_sizes", "8,8",
                    "--embedding_dim", "8",
                    "--num_embeddings", "64",
                    "--batch_size", "2",
                    "--in_memory_binary_criteo_path", os.path.dirname(files[0]),
                    "--epochs", "2",
                ]
            )

    @test_utils.skip_if_asan
    def test_main_function_criteo_in_memory(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdir:
            lc = LaunchConfig(
                min_nodes=1,
                max_nodes=1,
                nproc_per_node=2,
                run_id=str(uuid.uuid4()),
                rdzv_backend="c10d",
                rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
                rdzv_configs={"store_type": "file"},
                start_method="spawn",
                monitor_interval=1,
                max_restarts=0,
            )
            elastic_launch(config=lc, entrypoint=self._run_trainer_criteo_in_memory)()

    @classmethod
    def _run_trainer_dcn(cls) -> None:
        with CriteoTest._create_dataset_npys(
            num_rows=50, filenames=[f"day_{i}" for i in range(24)]
        ) as files:
            main(
                [
                    "--over_arch_layer_sizes", "8,1",
                    "--dense_arch_layer_sizes", "8,8",
                    "--embedding_dim", "8",
                    "--num_embeddings", "64",
                    "--batch_size", "2",
                    "--in_memory_binary_criteo_path", os.path.dirname(files[0]),
                    "--epochs", "2",
                    "--interaction_type", "dcn",
                    "--dcn_num_layers", "2",
                    "--dcn_low_rank_dim", "8",
                ]
            )

    @test_utils.skip_if_asan
    def test_main_function_dcn(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdir:
            lc = LaunchConfig(
                min_nodes=1,
                max_nodes=1,
                nproc_per_node=2,
                run_id=str(uuid.uuid4()),
                rdzv_backend="c10d",
                rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
                rdzv_configs={"store_type": "file"},
                start_method="spawn",
                monitor_interval=1,
                max_restarts=0,
            )
            elastic_launch(config=lc, entrypoint=self._run_trainer_dcn)()
tricks/md_embedding_bag.py (new file, 0 → 100644)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Mixed-Dimensions Trick
#
# Description: Applies mixed dimension trick to embeddings to reduce
# embedding sizes.
#
# References:
# [1] Antonio Ginart, Maxim Naumov, Dheevatsa Mudigere, Jiyan Yang, James Zou,
# "Mixed Dimension Embeddings with Application to Memory-Efficient Recommendation
# Systems", CoRR, arXiv:1909.11810, 2019

from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch.nn as nn


def md_solver(n, alpha, d0=None, B=None, round_dim=True, k=None):
    """
    An external facing function call for mixed-dimension assignment
    with the alpha power temperature heuristic
    Inputs:
    n -- (torch.LongTensor) ; Vector of num of rows for each embedding matrix
    alpha -- (torch.FloatTensor); Scalar, non-negative, controls dim. skew
    d0 -- (torch.FloatTensor); Scalar, baseline embedding dimension
    B -- (torch.FloatTensor); Scalar, parameter budget for embedding layer
    round_dim -- (bool); flag for rounding dims to nearest pow of 2
    k -- (torch.LongTensor) ; Vector of average number of queries per inference
    """
    n, indices = torch.sort(n)
    k = k[indices] if k is not None else torch.ones(len(n))
    d = alpha_power_rule(n.type(torch.float) / k, alpha, d0=d0, B=B)
    if round_dim:
        d = pow_2_round(d)
    undo_sort = [0] * len(indices)
    for i, v in enumerate(indices):
        undo_sort[v] = i
    return d[undo_sort]


def alpha_power_rule(n, alpha, d0=None, B=None):
    if d0 is not None:
        lamb = d0 * (n[0].type(torch.float) ** alpha)
    elif B is not None:
        lamb = B / torch.sum(n.type(torch.float) ** (1 - alpha))
    else:
        raise ValueError("Must specify either d0 or B")
    d = torch.ones(len(n)) * lamb * (n.type(torch.float) ** (-alpha))
    for i in range(len(d)):
        if i == 0 and d0 is not None:
            d[i] = d0
        else:
            d[i] = 1 if d[i] < 1 else d[i]
    return torch.round(d).type(torch.long)


def pow_2_round(dims):
    return 2 ** torch.round(torch.log2(dims.type(torch.float)))


class PrEmbeddingBag(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, base_dim):
        super(PrEmbeddingBag, self).__init__()
        self.embs = nn.EmbeddingBag(
            num_embeddings, embedding_dim, mode="sum", sparse=True
        )
        torch.nn.init.xavier_uniform_(self.embs.weight)
        if embedding_dim < base_dim:
            self.proj = nn.Linear(embedding_dim, base_dim, bias=False)
            torch.nn.init.xavier_uniform_(self.proj.weight)
        elif embedding_dim == base_dim:
            self.proj = nn.Identity()
        else:
            raise ValueError(
                "Embedding dim " + str(embedding_dim) + " > base dim " + str(base_dim)
            )

    def forward(self, input, offsets=None, per_sample_weights=None):
        return self.proj(
            self.embs(input, offsets=offsets, per_sample_weights=per_sample_weights)
        )
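md_solver implements the alpha power rule from [1]: each table's dimension is proportional to lamb * n^(-alpha), with lamb derived either from a baseline dimension d0 or a total parameter budget B, optionally rounded to the nearest power of two; PrEmbeddingBag then projects tables with smaller dimensions back up to a common base dimension. The sketch below, not part of this commit, shows one way these pieces could be combined; the table sizes, alpha, budget, and the `tricks` package import path are illustrative assumptions.

# Illustrative sketch of mixed-dimension tables (not part of the commit).
import torch

from tricks.md_embedding_bag import md_solver, PrEmbeddingBag  # assumes tricks/ is importable

n = torch.tensor([1_000_000, 50_000, 300])      # rows per embedding table (illustrative)
d = md_solver(n, alpha=0.3, B=32 * n.sum().item(), round_dim=True)
base_dim = int(d.max().item())                   # larger tables get smaller dims

# One PrEmbeddingBag per table; each projects its output up to base_dim.
tables = [
    PrEmbeddingBag(int(rows), int(dim), base_dim)
    for rows, dim in zip(n.tolist(), d.tolist())
]

indices = torch.tensor([3, 7, 7])                # one bag of lookups into the first table
offsets = torch.tensor([0])
print(tables[0](indices, offsets).shape)         # (1, base_dim)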
tricks/qr_embedding_bag.py (new file, 0 → 100644)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Quotient-Remainder Trick
#
# Description: Applies the quotient-remainder trick to embeddings to reduce
# embedding sizes.
#
# References:
# [1] Hao-Jun Michael Shi, Dheevatsa Mudigere, Maxim Naumov, Jiyan Yang,
# "Compositional Embeddings Using Complementary Partitions for Memory-Efficient
# Recommendation Systems", CoRR, arXiv:1909.02107, 2019

from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class QREmbeddingBag(nn.Module):
    r"""Computes sums or means over two 'bags' of embeddings, one using the quotient
    of the indices and the other using the remainder of the indices, without
    instantiating the intermediate embeddings, then performs an operation to combine these.

    For bags of constant length and no :attr:`per_sample_weights`, this class

    * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=0)``,
    * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=0)``,
    * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=0)``.

    However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
    operations.

    QREmbeddingBag also supports per-sample weights as an argument to the forward
    pass. This scales the output of the Embedding before performing a weighted
    reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
    only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
    :attr:`per_sample_weights`.

    Known Issues:
    Autograd breaks with multiple GPUs. It breaks only with multiple embeddings.

    Args:
        num_categories (int): total number of unique categories. The input indices must be in
            0, 1, ..., num_categories - 1.
        embedding_dim (list): list of sizes for each embedding vector in each table. If ``"add"``
            or ``"mult"`` operation are used, these embedding dimensions must be
            the same. If a single embedding_dim is used, then it will use this
            embedding_dim for both embedding tables.
        num_collisions (int): number of collisions to enforce.
        operation (string, optional): ``"concat"``, ``"add"``, or ``"mult"``. Specifies the operation
            to compose embeddings. ``"concat"`` concatenates the embeddings,
            ``"add"`` sums the embeddings, and ``"mult"`` multiplies
            (component-wise) the embeddings.
            Default: ``"mult"``
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
            is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of
            the words in the mini-batch. Default ``False``.
            Note: this option is not supported when ``mode="max"``.
        mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
            ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
            into consideration. ``"mean"`` computes the average of the values
            in the bag, ``"max"`` computes the max value over each bag.
            Default: ``"mean"``
        sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
            Notes for more details regarding sparse gradients. Note: this option is not
            supported when ``mode="max"``.

    Attributes:
        weight (Tensor): the learnable weights of each embedding table is the module of shape
            `(num_embeddings, embedding_dim)` initialized using a uniform distribution
            with sqrt(1 / num_categories).

    Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
        :attr:`per_index_weights` (Tensor, optional)

        - If :attr:`input` is 2D of shape `(B, N)`,
          it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and
          this will return ``B`` values aggregated in a way depending on the :attr:`mode`.
          :attr:`offsets` is ignored and required to be ``None`` in this case.

        - If :attr:`input` is 1D of shape `(N)`,
          it will be treated as a concatenation of multiple bags (sequences).
          :attr:`offsets` is required to be a 1D tensor containing the
          starting index positions of each bag in :attr:`input`. Therefore,
          for :attr:`offsets` of shape `(B)`, :attr:`input` will be viewed as
          having ``B`` bags. Empty bags (i.e., having 0-length) will have
          returned vectors filled by zeros.

        per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
            to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
            must have exactly the same shape as input and is treated as having the same
            :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.

    Output shape: `(B, embedding_dim)`
    """

    __constants__ = [
        "num_categories",
        "embedding_dim",
        "num_collisions",
        "operation",
        "max_norm",
        "norm_type",
        "scale_grad_by_freq",
        "mode",
        "sparse",
    ]

    def __init__(
        self,
        num_categories,
        embedding_dim,
        num_collisions,
        operation="mult",
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        mode="mean",
        sparse=False,
        _weight=None,
    ):
        super(QREmbeddingBag, self).__init__()

        assert operation in ["concat", "mult", "add"], "Not valid operation!"

        self.num_categories = num_categories
        if isinstance(embedding_dim, int) or len(embedding_dim) == 1:
            self.embedding_dim = [embedding_dim, embedding_dim]
        else:
            self.embedding_dim = embedding_dim
        self.num_collisions = num_collisions
        self.operation = operation
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq

        if self.operation == "add" or self.operation == "mult":
            assert (
                self.embedding_dim[0] == self.embedding_dim[1]
            ), "Embedding dimensions do not match!"

        self.num_embeddings = [
            int(np.ceil(num_categories / num_collisions)),
            num_collisions,
        ]

        if _weight is None:
            self.weight_q = Parameter(
                torch.Tensor(self.num_embeddings[0], self.embedding_dim[0])
            )
            self.weight_r = Parameter(
                torch.Tensor(self.num_embeddings[1], self.embedding_dim[1])
            )
            self.reset_parameters()
        else:
            assert list(_weight[0].shape) == [
                self.num_embeddings[0],
                self.embedding_dim[0],
            ], "Shape of weight for quotient table does not match num_embeddings and embedding_dim"
            assert list(_weight[1].shape) == [
                self.num_embeddings[1],
                self.embedding_dim[1],
            ], "Shape of weight for remainder table does not match num_embeddings and embedding_dim"
            self.weight_q = Parameter(_weight[0])
            self.weight_r = Parameter(_weight[1])
        self.mode = mode
        self.sparse = sparse

    def reset_parameters(self):
        nn.init.uniform_(self.weight_q, np.sqrt(1 / self.num_categories))
        nn.init.uniform_(self.weight_r, np.sqrt(1 / self.num_categories))

    def forward(self, input, offsets=None, per_sample_weights=None):
        input_q = (input / self.num_collisions).long()
        input_r = torch.remainder(input, self.num_collisions).long()

        embed_q = F.embedding_bag(
            input_q,
            self.weight_q,
            offsets,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.mode,
            self.sparse,
            per_sample_weights,
        )
        embed_r = F.embedding_bag(
            input_r,
            self.weight_r,
            offsets,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.mode,
            self.sparse,
            per_sample_weights,
        )

        if self.operation == "concat":
            embed = torch.cat((embed_q, embed_r), dim=1)
        elif self.operation == "add":
            embed = embed_q + embed_r
        elif self.operation == "mult":
            embed = embed_q * embed_r

        return embed

    def extra_repr(self):
        s = "{num_embeddings}, {embedding_dim}"
        if self.max_norm is not None:
            s += ", max_norm={max_norm}"
        if self.norm_type != 2:
            s += ", norm_type={norm_type}"
        if self.scale_grad_by_freq is not False:
            s += ", scale_grad_by_freq={scale_grad_by_freq}"
        s += ", mode={mode}"
        return s.format(**self.__dict__)
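QREmbeddingBag replaces one large embedding table with a quotient table of ceil(num_categories / num_collisions) rows and a remainder table of num_collisions rows, looks each index up in both, and combines the two vectors with the chosen operation. A minimal usage sketch follows, not part of this commit; the category count, collision count, dimension, and import path are illustrative assumptions.

# Illustrative sketch of QREmbeddingBag usage (not part of the commit).
import torch

from tricks.qr_embedding_bag import QREmbeddingBag  # assumes tricks/ is importable

qr_emb = QREmbeddingBag(
    num_categories=1_000_000,   # illustrative: one logical table of 1M categories
    embedding_dim=16,
    num_collisions=1_000,       # quotient table has ceil(1e6/1e3)=1000 rows, remainder table 1000 rows
    operation="mult",
    mode="sum",
)

indices = torch.tensor([5, 123_456, 999_999])  # category ids in [0, num_categories)
offsets = torch.tensor([0, 2])                 # two bags: sizes 2 and 1
print(qr_emb(indices, offsets).shape)          # torch.Size([2, 16])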