"examples/pytorch/vscode:/vscode.git/clone" did not exist on "bef993076ea1369aea8f58469f15523d4cda97ba"
Commit 8d7f508a authored by Neel Kant's avatar Neel Kant
Browse files

Addressed Jared's comments

parent 03feecbc
...@@ -272,15 +272,15 @@ Loosely, they are pretraining the retriever modules, then jointly training the l ...@@ -272,15 +272,15 @@ Loosely, they are pretraining the retriever modules, then jointly training the l
### Inverse Cloze Task (ICT) Pretraining ### Inverse Cloze Task (ICT) Pretraining
1. Have a corpus in loose JSON format with the intention of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document. 1. Have a corpus in loose JSON format with the intention of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document.
Run `tools/preprocess_data.py` to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. For the original REALM system, we construct two datasets, one with the title of every document, and another with the body. Run `tools/preprocess_data.py` to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. For the original REALM system, we construct two datasets, one with the title of every document, and another with the body.
Refer to the following script meant to be run in an interactive session on draco: Refer to the following script
<pre> <pre>
python preprocess_data.py \ python preprocess_data.py \
--input /home/universal-lm-data.cosmos549/datasets/wikipedia/wikidump_lines.json \ --input /path/to/corpus.json \
--json-keys text title \ --json-keys text title \
--split-sentences \ --split-sentences \
--tokenizer-type BertWordPieceLowerCase \ --tokenizer-type BertWordPieceLowerCase \
--vocab-file /home/universal-lm-data.cosmos549/scratch/mshoeybi/data/albert/vocab.txt \ --vocab-file /path/to/vocab.txt \
--output-prefix wiki_indexed \ --output-prefix corpus_indexed \
--workers 5 # works well for 10 CPU cores. Scale up accordingly. --workers 5 # works well for 10 CPU cores. Scale up accordingly.
</pre> </pre>
...@@ -288,13 +288,10 @@ python preprocess_data.py \ ...@@ -288,13 +288,10 @@ python preprocess_data.py \
The samples mapping is responsible for holding all of the required metadata needed to construct the sample from one or more indexed datasets. In REALM, the samples mapping contains the start and end sentence indices, as well as the document index (to find the correct title for a body) and a unique ID for every block. The samples mapping is responsible for holding all of the required metadata needed to construct the sample from one or more indexed datasets. In REALM, the samples mapping contains the start and end sentence indices, as well as the document index (to find the correct title for a body) and a unique ID for every block.
3. Pretrain a BERT language model using `pretrain_bert.py`, with the sequence length equal to the block size in token ids. This model should be trained on the same indexed dataset that is used to supply the blocks for the information retrieval task. 3. Pretrain a BERT language model using `pretrain_bert.py`, with the sequence length equal to the block size in token ids. This model should be trained on the same indexed dataset that is used to supply the blocks for the information retrieval task.
In REALM, this is an uncased bert base model trained with the standard hyperparameters. In REALM, this is an uncased bert base model trained with the standard hyperparameters.
4. Use `pretrain_bert_ict.py` to train an `ICTBertModel` which uses two BERT-based encoders to encode queries and blocks to perform retrieval with. 4. Use `pretrain_ict.py` to train an `ICTBertModel` which uses two BERT-based encoders to encode queries and blocks to perform retrieval with.
The script below trains the ICT model from REALM on draco. It refrences a pretrained BERT model (step 3) in the `--bert-load` argument. The script below trains the ICT model from REALM. It refrences a pretrained BERT model (step 3) in the `--bert-load` argument. The batch size used in the paper is 4096, so this would need to be run with data parallel world size 32.
<pre> <pre>
EXPNAME="ict_wikipedia" python pretrain_ict.py \
CHKPT="chkpts/${EXPNAME}"
LOGDIR="logs/${EXPNAME}"
COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch python pretrain_bert_ict.py \
--num-layers 12 \ --num-layers 12 \
--num-attention-heads 12 \ --num-attention-heads 12 \
--hidden-size 768 \ --hidden-size 768 \
...@@ -304,13 +301,12 @@ COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch ...@@ -304,13 +301,12 @@ COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch
--ict-head-size 128 \ --ict-head-size 128 \
--train-iters 100000 \ --train-iters 100000 \
--checkpoint-activations \ --checkpoint-activations \
--bert-load /home/dcg-adlr-nkant-output.cosmos1203/chkpts/base_bert_seq256 \ --bert-load /path/to/pretrained_bert \
--load CHKPT \ --load checkpoints \
--save CHKPT \ --save checkpoints \
--data-path /home/dcg-adlr-nkant-data.cosmos1202/wiki/wikipedia_lines \ --data-path /path/to/indexed_dataset \
--titles-data-path /home/dcg-adlr-nkant-data.cosmos1202/wiki/wikipedia_lines-titles \ --titles-data-path /path/to/titles_indexed_dataset \
--vocab-file /home/universal-lm-data.cosmos549/scratch/mshoeybi/data/albert/vocab.txt \ --vocab-file /path/to/vocab.txt \
--distributed-backend nccl \
--lr 0.0001 \ --lr 0.0001 \
--num-workers 2 \ --num-workers 2 \
--lr-decay-style linear \ --lr-decay-style linear \
...@@ -319,11 +315,8 @@ COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch ...@@ -319,11 +315,8 @@ COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch
--warmup .01 \ --warmup .01 \
--save-interval 3000 \ --save-interval 3000 \
--query-in-block-prob 0.1 \ --query-in-block-prob 0.1 \
--fp16 \ --fp16
--adlr-autoresume \
--adlr-autoresume-interval 100"
submit_job --image 'http://gitlab-master.nvidia.com/adlr/megatron-lm/megatron:20.03_faiss' --mounts /home/universal-lm-data.cosmos549,/home/dcg-adlr-nkant-data.cosmos1202,/home/dcg-adlr-nkant-output.cosmos1203,/home/nkant --name "${EXPNAME}" --partition batch_32GB --gpu 8 --nodes 4 --autoresume_timer 420 -c "${COMMAND}" --logdir "${LOGDIR}"
</pre> </pre>
<a id="evaluation-and-tasks"></a> <a id="evaluation-and-tasks"></a>
......
...@@ -37,6 +37,7 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -37,6 +37,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_validation_args(parser) parser = _add_validation_args(parser)
parser = _add_data_args(parser) parser = _add_data_args(parser)
parser = _add_autoresume_args(parser) parser = _add_autoresume_args(parser)
parser = _add_realm_args(parser)
# Custom arguments. # Custom arguments.
if extra_args_provider is not None: if extra_args_provider is not None:
...@@ -139,8 +140,6 @@ def _add_network_size_args(parser): ...@@ -139,8 +140,6 @@ def _add_network_size_args(parser):
' grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].') ' grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
group.add_argument('--hidden-size', type=int, default=None, group.add_argument('--hidden-size', type=int, default=None,
help='Tansformer hidden size.') help='Tansformer hidden size.')
group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
group.add_argument('--num-attention-heads', type=int, default=None, group.add_argument('--num-attention-heads', type=int, default=None,
help='Number of transformer attention heads.') help='Number of transformer attention heads.')
group.add_argument('--max-position-embeddings', type=int, default=None, group.add_argument('--max-position-embeddings', type=int, default=None,
...@@ -264,10 +263,6 @@ def _add_checkpointing_args(parser): ...@@ -264,10 +263,6 @@ def _add_checkpointing_args(parser):
help='Do not save current rng state.') help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None, group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.') help='Directory containing a model checkpoint.')
group.add_argument('--ict-load', type=str, default=None,
help='Directory containing an ICTBertModel checkpoint')
group.add_argument('--bert-load', type=str, default=None,
help='Directory containing an BertModel checkpoint (needed to start ICT and REALM)')
group.add_argument('--no-load-optim', action='store_true', group.add_argument('--no-load-optim', action='store_true',
help='Do not load optimizer when loading checkpoint.') help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true', group.add_argument('--no-load-rng', action='store_true',
...@@ -347,8 +342,6 @@ def _add_data_args(parser): ...@@ -347,8 +342,6 @@ def _add_data_args(parser):
group.add_argument('--data-path', type=str, default=None, group.add_argument('--data-path', type=str, default=None,
help='Path to combined dataset to split.') help='Path to combined dataset to split.')
group.add_argument('--titles-data-path', type=str, default=None,
help='Path to titles dataset used for ICT')
group.add_argument('--split', type=str, default='969, 30, 1', group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,' help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split ' ' validation, and test split. For example the split '
...@@ -384,10 +377,6 @@ def _add_data_args(parser): ...@@ -384,10 +377,6 @@ def _add_data_args(parser):
'end-of-document token.') 'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true', group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens.') help='Mask loss for the end of document tokens.')
group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for ICT dataset')
group.add_argument('--ict-one-sent', action='store_true',
help='Whether to use one sentence documents in ICT')
return parser return parser
...@@ -402,3 +391,28 @@ def _add_autoresume_args(parser): ...@@ -402,3 +391,28 @@ def _add_autoresume_args(parser):
'termination signal') 'termination signal')
return parser return parser
def _add_realm_args(parser):
group = parser.add_argument_group(title='realm')
# network size
group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
# checkpointing
group.add_argument('--ict-load', type=str, default=None,
help='Directory containing an ICTBertModel checkpoint')
group.add_argument('--bert-load', type=str, default=None,
help='Directory containing an BertModel checkpoint (needed to start ICT and REALM)')
# data
group.add_argument('--titles-data-path', type=str, default=None,
help='Path to titles dataset used for ICT')
group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for ICT dataset')
group.add_argument('--ict-one-sent', action='store_true',
help='Whether to use one sentence documents in ICT')
return parser
...@@ -23,8 +23,9 @@ import numpy as np ...@@ -23,8 +23,9 @@ import numpy as np
import torch import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import mpu, print_rank_0 from megatron import mpu
from megatron import get_args from megatron import get_args
from megatron import print_rank_0
def check_checkpoint_args(checkpoint_args): def check_checkpoint_args(checkpoint_args):
......
...@@ -22,7 +22,8 @@ import numpy as np ...@@ -22,7 +22,8 @@ import numpy as np
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from megatron import get_tokenizer, get_args, print_rank_0 from megatron import get_tokenizer, get_args
from megatron import print_rank_0
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import get_a_and_b_segments from megatron.data.dataset_utils import get_a_and_b_segments
from megatron.data.dataset_utils import truncate_segments from megatron.data.dataset_utils import truncate_segments
......
...@@ -399,7 +399,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -399,7 +399,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
def build_dataset(index, name): def build_dataset(index, name):
from megatron.data.bert_dataset import BertDataset from megatron.data.bert_dataset import BertDataset
from megatron.data.realm_dataset import ICTDataset from megatron.data.ict_dataset import ICTDataset
dataset = None dataset = None
if splits[index + 1] > splits[index]: if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later. # Get the pointer to the original doc-idx so we can set it later.
......
...@@ -26,12 +26,46 @@ from megatron.model.utils import openai_gelu ...@@ -26,12 +26,46 @@ from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from megatron.model.utils import bert_attention_mask_func
from megatron.model.utils import bert_extended_attention_mask
from megatron.model.utils import bert_position_ids
from megatron.module import MegatronModule from megatron.module import MegatronModule
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask
return attention_scores
def bert_extended_attention_mask(attention_mask, dtype):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Since attention_mask is 1.0 for positions we want to attend and 0.0
# for masked positions, this operation will create a tensor which is
# 0.0 for positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
class BertLMHead(MegatronModule): class BertLMHead(MegatronModule):
"""Masked LM head for Bert """Masked LM head for Bert
...@@ -171,5 +205,3 @@ class BertModel(MegatronModule): ...@@ -171,5 +205,3 @@ class BertModel(MegatronModule):
if self.add_binary_head: if self.add_binary_head:
self.binary_head.load_state_dict( self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict) state_dict[self._binary_head_key], strict=strict)
...@@ -18,9 +18,7 @@ ...@@ -18,9 +18,7 @@
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args, print_rank_0
from megatron.model.bert_model import bert_attention_mask_func from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
......
...@@ -18,9 +18,7 @@ ...@@ -18,9 +18,7 @@
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args, print_rank_0
from megatron.model.bert_model import bert_attention_mask_func from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
......
...@@ -10,9 +10,7 @@ from megatron.model.utils import get_linear_layer ...@@ -10,9 +10,7 @@ from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from megatron.model.utils import bert_attention_mask_func from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.utils import bert_extended_attention_mask
from megatron.model.utils import bert_position_ids
class ICTBertModel(MegatronModule): class ICTBertModel(MegatronModule):
......
...@@ -78,42 +78,3 @@ def get_params_for_weight_decay_optimization(module): ...@@ -78,42 +78,3 @@ def get_params_for_weight_decay_optimization(module):
if p is not None and n == 'bias']) if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params return weight_decay_params, no_weight_decay_params
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask
return attention_scores
def bert_extended_attention_mask(attention_mask, dtype):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Since attention_mask is 1.0 for positions we want to attend and 0.0
# for masked positions, this operation will create a tensor which is
# 0.0 for positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
...@@ -22,10 +22,11 @@ import torch ...@@ -22,10 +22,11 @@ import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam from apex.optimizers import FusedAdam as Adam
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import get_timers from megatron import get_timers
from megatron import get_tensorboard_writer from megatron import get_tensorboard_writer
from megatron import mpu from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module from megatron.fp16 import FP16_Module
...@@ -217,7 +218,11 @@ def setup_model_and_optimizer(model_provider_func): ...@@ -217,7 +218,11 @@ def setup_model_and_optimizer(model_provider_func):
else: else:
args.iteration = 0 args.iteration = 0
unwrapped_model = model.module.module # get model without FP16 and/or TorchDDP wrappers
unwrapped_model = model
while hasattr(unwrapped_model, 'module'):
unwrapped_model = unwrapped_model.module
if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'): if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'):
print("Initializing ICT from pretrained BERT model", flush=True) print("Initializing ICT from pretrained BERT model", flush=True)
unwrapped_model.init_state_dict_from_bert() unwrapped_model.init_state_dict_from_bert()
......
...@@ -19,7 +19,8 @@ import sys ...@@ -19,7 +19,8 @@ import sys
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume from megatron import get_adlr_autoresume
from megatron import mpu from megatron import mpu
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
......
...@@ -19,7 +19,8 @@ import torch ...@@ -19,7 +19,8 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
......
...@@ -20,7 +20,8 @@ import time ...@@ -20,7 +20,8 @@ import time
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import mpu from megatron import mpu
from tasks.finetune_utils import build_data_loader from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch from tasks.finetune_utils import process_batch
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
"""GLUE finetuning/evaluation.""" """GLUE finetuning/evaluation."""
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron.model.classification import Classification from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider from tasks.eval_utils import accuracy_func_provider
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
"""Race.""" """Race."""
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron.model.multiple_choice import MultipleChoice from megatron.model.multiple_choice import MultipleChoice
from tasks.eval_utils import accuracy_func_provider from tasks.eval_utils import accuracy_func_provider
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment