Commit a96ad63f authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

fixed alignment

parent 651cc107
...@@ -644,14 +644,18 @@ def _add_biencoder_args(parser): ...@@ -644,14 +644,18 @@ def _add_biencoder_args(parser):
# network size # network size
group.add_argument('--ict-head-size', type=int, default=None, group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and REALM (paper default: 128)') help='Size of block embeddings to be used in ICT and '
'REALM (paper default: 128)')
group.add_argument('--projection-dim', type=int, default=0, group.add_argument('--projection-dim', type=int, default=0,
help='Size of projection head used in biencoder (paper default: 128)') help='Size of projection head used in biencoder (paper'
' default: 128)')
group.add_argument('--shared-query-context-model', action='store_true', group.add_argument('--shared-query-context-model', action='store_true',
help='Whether to share the parameters of the query and context models or not') help='Whether to share the parameters of the query '
'and context models or not')
group.add_argument('--pool-type', type=str, default='cls-token', group.add_argument('--pool-type', type=str, default='cls-token',
choices=['avg', 'cls-token', 'max'], choices=['avg', 'cls-token', 'max'],
help='different options are: avg | cls-token | max, default=cls-token') help='different options are: avg | cls-token | max, '
'default=cls-token')
# checkpointing # checkpointing
group.add_argument('--ict-load', type=str, default=None, group.add_argument('--ict-load', type=str, default=None,
...@@ -670,10 +674,12 @@ def _add_biencoder_args(parser): ...@@ -670,10 +674,12 @@ def _add_biencoder_args(parser):
help='Whether to use one sentence documents in ICT') help='Whether to use one sentence documents in ICT')
# training # training
group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[], group.add_argument('--report-topk-accuracies', nargs='+', type=int,
help="Which top-k accuracies to report (e.g. '1 5 20')") default=[], help="Which top-k accuracies to report "
"(e.g. '1 5 20')")
group.add_argument('--retriever-score-scaling', action='store_true', group.add_argument('--retriever-score-scaling', action='store_true',
help="Whether to scale retriever scores by inverse square root of hidden size") help="Whether to scale retriever scores by inverse "
"square root of hidden size")
# faiss index # faiss index
group.add_argument('--faiss-use-gpu', action='store_true', group.add_argument('--faiss-use-gpu', action='store_true',
......
...@@ -34,9 +34,11 @@ from megatron.data.biencoder_dataset_utils import get_ict_batch ...@@ -34,9 +34,11 @@ from megatron.data.biencoder_dataset_utils import get_ict_batch
def pretrain_ict_model_provider(): def pretrain_ict_model_provider():
args = get_args() args = get_args()
model = biencoder_model_provider(only_context_model=False, model = biencoder_model_provider(
only_query_model=False, only_context_model=False,
shared_query_context_model=args.shared_query_context_model) only_query_model=False,
shared_query_context_model=args.shared_query_context_model
)
return model return model
def get_group_world_size_rank(): def get_group_world_size_rank():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment