Commit 76640072 authored by A. Unique TensorFlower
Browse files

[Ranking] Updating training scripts for ranking model.

PiperOrigin-RevId: 385629372
parent 703688aa
...@@ -111,6 +111,7 @@ export TPU_NAME=my-dlrm-tpu ...@@ -111,6 +111,7 @@ export TPU_NAME=my-dlrm-tpu
export EXPERIMENT_NAME=my_experiment_name export EXPERIMENT_NAME=my_experiment_name
export BUCKET_NAME="gs://my_dlrm_bucket" export BUCKET_NAME="gs://my_dlrm_bucket"
export DATA_DIR="${BUCKET_NAME}/data" export DATA_DIR="${BUCKET_NAME}/data"
export EMBEDDING_DIM=32
python3 models/official/recommendation/ranking/train.py --mode=train_and_eval \ python3 models/official/recommendation/ranking/train.py --mode=train_and_eval \
--model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override=" --model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
...@@ -126,8 +127,8 @@ task: ...@@ -126,8 +127,8 @@ task:
global_batch_size: 16384 global_batch_size: 16384
model: model:
num_dense_features: 13 num_dense_features: 13
bottom_mlp: [512,256,128] bottom_mlp: [512,256,${EMBEDDING_DIM}]
embedding_dim: 128 embedding_dim: ${EMBEDDING_DIM}
top_mlp: [1024,1024,512,256,1] top_mlp: [1024,1024,512,256,1]
interaction: 'dot' interaction: 'dot'
vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63, vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63,
...@@ -135,8 +136,8 @@ task: ...@@ -135,8 +136,8 @@ task:
39979771, 25641295, 39664984, 585935, 12972, 108, 36] 39979771, 25641295, 39664984, 585935, 12972, 108, 36]
trainer: trainer:
use_orbit: true use_orbit: true
validation_interval: 90000 validation_interval: 85352
checkpoint_interval: 100000 checkpoint_interval: 85352
validation_steps: 5440 validation_steps: 5440
train_steps: 256054 train_steps: 256054
steps_per_loop: 1000 steps_per_loop: 1000
...@@ -154,6 +155,8 @@ Training on GPUs is similar to TPU training. Only the distribution strategy needs ...@@ -154,6 +155,8 @@ Training on GPUs is similar to TPU training. Only the distribution strategy needs
to be updated and the number of GPUs provided (for 4 GPUs): to be updated and the number of GPUs provided (for 4 GPUs):
```shell ```shell
export EMBEDDING_DIM=8
python3 official/recommendation/ranking/train.py --mode=train_and_eval \ python3 official/recommendation/ranking/train.py --mode=train_and_eval \
--model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override=" --model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
runtime: runtime:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment