Commit be136206 authored by rprenger

Cleaning up Switch Transformer code

parent 1ec6e720
@@ -343,7 +343,8 @@ def _add_network_size_args(parser):
     group.add_argument('--bert-no-binary-head', action='store_false',
                        help='Disable BERT binary head.',
                        dest='bert_binary_head')
+    group.add_argument('--num-experts', type=int, default=None,
+                       help='Number of Experts in Switch Transformer (None means no Switch)')
     return parser
...
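The --num-experts flag added above is optional: leaving it unset keeps the existing dense ParallelMLP path, while any integer turns on SwitchMLP with that many experts (see the transformer diff below). A minimal standalone sketch of the flag's behavior, using plain argparse rather than Megatron's actual parser:

    import argparse

    # Hypothetical stand-in for Megatron's argument group; only the new flag is shown.
    parser = argparse.ArgumentParser()
    parser.add_argument('--num-experts', type=int, default=None,
                        help='Number of Experts in Switch Transformer (None means no Switch)')

    print(parser.parse_args([]).num_experts)                       # None -> dense ParallelMLP
    print(parser.parse_args(['--num-experts', '8']).num_experts)   # 8 -> SwitchMLP with 8 experts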
@@ -99,12 +99,12 @@ class SwitchMLP(MegatronModule):
     """
     Routes input to one of N MLP "experts"
     """
-    def __init__(self, init_method, output_layer_init_method, num_experts):
+    def __init__(self, init_method, output_layer_init_method):
         super(SwitchMLP, self).__init__()
         args = get_args()
-        self.router = torch.nn.Linear(args.hidden_size, num_experts)
+        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
         self.experts = torch.nn.ModuleList()
-        for i in range(num_experts):
+        for i in range(args.num_experts):
             self.experts.append(ParallelMLP(init_method, output_layer_init_method))
 
     def forward(self, hidden_states):
@@ -113,16 +113,20 @@ class SwitchMLP(MegatronModule):
         s = hidden_states.size(1)
         h = hidden_states.size(2)
         route = self.router(hidden_states)
-        route = torch.nn.functional.softmax(route,dim=2)
+        route = torch.nn.functional.softmax(route, dim=2)
         max_prob, max_ind = torch.max(route, dim=2)
         max_prob = torch.unsqueeze(max_prob, 2)
 
+        # TODO (rprenger): this could be made easier to read
+        # Converting [b, s, h] to [b*s, h].
+        # Each vector could be routed differently.
         hidden_states = hidden_states.permute(2,0,1).view(hidden_states.size(2), -1).permute(1,0).unsqueeze(1)
         max_prob = max_prob.permute(2,0,1).view(max_prob.size(2), -1).permute(1,0).unsqueeze(1)
         max_ind = max_ind.view(-1)
 
         output_total = torch.empty_like(hidden_states)
         output_bias_total = torch.empty_like(hidden_states)
+        # TODO (rprenger): this does each expert in serial, but it could be parallelized
         for expert_num, expert in enumerate(self.experts):
             ind = (max_ind==expert_num).nonzero().unsqueeze(2).repeat(1,1, h)
             hidden = torch.gather(hidden_states, 0, ind)
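To make the permute/gather bookkeeping above easier to follow, here is a self-contained sketch of the same top-1 routing on toy tensors. The shapes (b, s, h), the plain torch.nn.Linear experts standing in for ParallelMLP, and the omission of the bias output are illustrative assumptions, not the commit's code:

    import torch

    b, s, h, num_experts = 2, 4, 8, 3
    hidden_states = torch.randn(b, s, h)
    router = torch.nn.Linear(h, num_experts)
    experts = torch.nn.ModuleList([torch.nn.Linear(h, h) for _ in range(num_experts)])

    # Top-1 routing: one expert index and its probability per token.
    route = torch.nn.functional.softmax(router(hidden_states), dim=2)
    max_prob, max_ind = torch.max(route, dim=2)          # [b, s]

    # Flatten [b, s, h] to [b*s, 1, h] so every token can be routed independently;
    # the permute/view/permute chain in the diff produces the same layout.
    flat = hidden_states.reshape(-1, 1, h)               # [b*s, 1, h]
    max_prob = max_prob.reshape(-1, 1, 1)                # [b*s, 1, 1]
    max_ind = max_ind.reshape(-1)                        # [b*s]

    output_total = torch.empty_like(flat)
    # Serial loop over experts, matching the TODO in the diff about parallelizing it.
    for expert_num, expert in enumerate(experts):
        ind = (max_ind == expert_num).nonzero().unsqueeze(2).repeat(1, 1, h)
        hidden = torch.gather(flat, 0, ind)              # gather the tokens routed here
        output_total.scatter_(0, ind, expert(hidden))    # write expert outputs back in place

    # Scale each token's output by its router probability (the Switch Transformer
    # weighting) and restore the original [b, s, h] shape.
    output_total = (output_total * max_prob).reshape(b, s, h)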
@@ -498,7 +502,10 @@ class ParallelTransformerLayer(MegatronModule):
             no_persist_layer_norm=args.no_persist_layer_norm)
 
         # MLP
-        self.mlp = SwitchMLP(init_method, output_layer_init_method, ${NUMEXPERTS})
+        if args.num_experts is not None:
+            self.mlp = SwitchMLP(init_method, output_layer_init_method)
+        else:
+            self.mlp = ParallelMLP(init_method, output_layer_init_method)
 
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
...
#!/bin/bash
#SBATCH -p luna -A adlr -t 4:00:00 --nodes=16 --exclusive --mem=0 --overcommit --ntasks-per-node=8 --dependency=singleton --job-name=adlr-nlp-largelm:switch_1.3b_RUNVAR_expert
NAME="gpt3-1.3b_switch_RUNVAR_expert"
DIR=`pwd`
DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
mkdir -p $DIR/logs
CHECKPOINT_DIR="/lustre/fsw/adlr/adlr-nlp/rprenger/switch/${NAME}"
TENSORBOARD_DIR="${CHECKPOINT_DIR}/tensorboard"
mkdir -p ${TENSORBOARD_DIR}
# Get the data blend
. /lustre/fsw/adlr/adlr-nlp/data/pile-cc1-cc2-shuf/gpt3_blend.sh
BPE_DIR="/lustre/fsw/adlr/adlr-nlp/data/pile-cc1-cc2-shuf/bpe"
options=" \
--exit-duration-in-mins 230 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 2048 \
--num-attention-heads 32 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--global-batch-size 512 \
--rampup-batch-size 32 32 2929688 \
--train-samples 192000000 \
--lr-decay-samples 166400000 \
--lr-warmup-samples 244141 \
--lr 2.0e-4 \
--min-lr 2.0e-5 \
--lr-decay-style cosine \
--log-interval 100 \
--eval-iters 50 \
--eval-interval 2000 \
--data-path ${DATA_BLEND} \
--vocab-file ${BPE_DIR}/gpt2-vocab.json \
--merge-file ${BPE_DIR}/gpt2-merges.txt \
--save-interval 10000 \
--save ${CHECKPOINT_DIR} \
--load ${CHECKPOINT_DIR} \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.014 \
--log-params-norm \
--log-num-zeros-in-grad \
--fp16 \
--DDP-impl torch \
--tensorboard-dir ${TENSORBOARD_DIR} \
--checkpoint-activations "
run_cmd="cd $DIR && python pretrain_gpt.py ${options}"
srun -l \
--container-image "/lustre/fsw/adlr/adlr-nlp/images/pytorch+bf16_nccl_fusion.sqsh" \
--container-mounts "/lustre/fsw/adlr:/lustre/fsw/adlr,/home/rprenger/workspace:/home/rprenger/workspace" \
--output=$DIR/logs/%x_%j_$DATETIME.log sh -c "${run_cmd}"
set +x
#!/bin/bash
#SBATCH -p luna -A adlr -t 4:00:00 --nodes=4 --exclusive --mem=0 --overcommit --ntasks-per-node=8 --dependency=singleton --job-name=adlr-nlp-largelm:switch_126m_RUNVAR_expert
NAME="gpt3-126m_switch_RUNVAR_expert"
DIR=`pwd`
DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
mkdir -p $DIR/logs
CHECKPOINT_DIR="/lustre/fsw/adlr/adlr-nlp/rprenger/switch/${NAME}"
TENSORBOARD_DIR="${CHECKPOINT_DIR}/tensorboard"
mkdir -p ${TENSORBOARD_DIR}
# Get the data blend
. /lustre/fsw/adlr/adlr-nlp/data/pile-cc1-cc2-shuf/gpt3_blend.sh
BPE_DIR="/lustre/fsw/adlr/adlr-nlp/data/pile-cc1-cc2-shuf/bpe"
options=" \
--exit-duration-in-mins 230 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 8 \
--global-batch-size 256 \
--rampup-batch-size 32 32 1953125 \
--train-samples 192000000 \
--lr-decay-samples 166400000 \
--lr-warmup-samples 162761 \
--lr 6.0e-4 \
--min-lr 6.0e-5 \
--lr-decay-style cosine \
--log-interval 100 \
--eval-iters 50 \
--eval-interval 2000 \
--data-path ${DATA_BLEND} \
--vocab-file ${BPE_DIR}/gpt2-vocab.json \
--merge-file ${BPE_DIR}/gpt2-merges.txt \
--save-interval 10000 \
--save ${CHECKPOINT_DIR} \
--load ${CHECKPOINT_DIR} \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.023 \
--log-params-norm \
--log-num-zeros-in-grad \
--fp16 \
--DDP-impl torch \
--tensorboard-dir ${TENSORBOARD_DIR} "
run_cmd="cd $DIR && python pretrain_gpt.py ${options}"
srun -l \
--container-image "/lustre/fsw/adlr/adlr-nlp/images/pytorch+bf16_nccl_fusion.sqsh" \
--container-mounts "/lustre/fsw/adlr:/lustre/fsw/adlr,/home/rprenger/workspace:/home/rprenger/workspace" \
--output=$DIR/logs/%x_%j_$DATETIME.log sh -c "${run_cmd}"
set +x
#!/bin/bash
#SBATCH -p luna -A adlr -t 4:00:00 --nodes=8 --exclusive --mem=0 --overcommit --ntasks-per-node=8 --dependency=singleton --job-name=adlr-nlp-largelm:switch_357m_RUNVAR_expert
NAME="gpt3-357m_switch_RUNVAR_expert"
DIR=`pwd`
DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
mkdir -p $DIR/logs
CHECKPOINT_DIR="/lustre/fsw/adlr/adlr-nlp/rprenger/switch/${NAME}"
TENSORBOARD_DIR="${CHECKPOINT_DIR}/tensorboard"
mkdir -p ${TENSORBOARD_DIR}
# Get the data blend
. /lustre/fsw/adlr/adlr-nlp/data/pile-cc1-cc2-shuf/gpt3_blend.sh
BPE_DIR="/lustre/fsw/adlr/adlr-nlp/data/pile-cc1-cc2-shuf/bpe"
options=" \
--exit-duration-in-mins 230 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--global-batch-size 256 \
--rampup-batch-size 32 32 1953125 \
--train-samples 192000000 \
--lr-decay-samples 166400000 \
--lr-warmup-samples 162761 \
--lr 3.0e-4 \
--min-lr 3.0e-5 \
--lr-decay-style cosine \
--log-interval 100 \
--eval-iters 50 \
--eval-interval 2000 \
--data-path ${DATA_BLEND} \
--vocab-file ${BPE_DIR}/gpt2-vocab.json \
--merge-file ${BPE_DIR}/gpt2-merges.txt \
--save-interval 10000 \
--save ${CHECKPOINT_DIR} \
--load ${CHECKPOINT_DIR} \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.02 \
--log-params-norm \
--log-num-zeros-in-grad \
--fp16 \
--DDP-impl torch \
--tensorboard-dir ${TENSORBOARD_DIR} \
--checkpoint-activations "
run_cmd="cd $DIR && python pretrain_gpt.py ${options}"
srun -l \
--container-image "/lustre/fsw/adlr/adlr-nlp/images/pytorch+bf16_nccl_fusion.sqsh" \
--container-mounts "/lustre/fsw/adlr:/lustre/fsw/adlr,/home/rprenger/workspace:/home/rprenger/workspace" \
--output=$DIR/logs/%x_%j_$DATETIME.log sh -c "${run_cmd}"
set +x