Commit 4018d92c authored by rprenger

Faster Switch code

parent d4169684
@@ -95,6 +95,49 @@ class ParallelMLP(MegatronModule):
        return output, output_bias


class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"
    """
    def __init__(self, init_method, output_layer_init_method, num_experts):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, num_experts)
        self.experts = torch.nn.ModuleList()
        for i in range(num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        # hidden_states: [b, s, h]
        b = hidden_states.size(0)
        s = hidden_states.size(1)
        h = hidden_states.size(2)

        # Top-1 routing: pick the highest-probability expert for each token.
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [b, s, 1]

        # Flatten batch and sequence dims so every token is a row: [b, s, h] -> [b*s, 1, h].
        hidden_states = hidden_states.permute(2, 0, 1).view(hidden_states.size(2), -1).permute(1, 0).unsqueeze(1)
        max_prob = max_prob.permute(2, 0, 1).view(max_prob.size(2), -1).permute(1, 0).unsqueeze(1)
        max_ind = max_ind.view(-1)  # expert index per token, [b*s]

        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)
        for expert_num, expert in enumerate(self.experts):
            # Gather the tokens routed to this expert, run them through it,
            # and scatter the results back to their original positions.
            ind = (max_ind == expert_num).nonzero().unsqueeze(2).repeat(1, 1, h)
            hidden = torch.gather(hidden_states, 0, ind)
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total.scatter_(0, ind, output)
            output_bias_total.scatter_(0, ind, output_bias)

        # Scale by the selected expert's router probability and restore [b, s, h].
        output_total = output_total * max_prob
        output_bias_total = output_bias_total * max_prob
        output_total = output_total.permute(2, 0, 1).view(h, b, s).permute(1, 2, 0)
        output_bias_total = output_bias_total.permute(2, 0, 1).view(h, b, s).permute(1, 2, 0)

        return output_total, output_bias_total


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.
@@ -455,8 +498,7 @@ class ParallelTransformerLayer(MegatronModule):
            no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)
        self.mlp = SwitchMLP(init_method, output_layer_init_method, ${NUMEXPERTS})

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
@@ -531,7 +573,7 @@ class ParallelTransformerLayer(MegatronModule):
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
......
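For reference, a minimal, self-contained sketch of the same top-1 gather/scatter routing on plain tensors. The toy sizes and plain torch.nn.Linear "experts" are assumptions standing in for ParallelMLP and the Megatron args, and the contiguous view(-1, h) used here is equivalent to the permute/view sequence in the patch above:

import torch

torch.manual_seed(0)
b, s, h, num_experts = 2, 4, 8, 3
hidden_states = torch.randn(b, s, h)

router = torch.nn.Linear(h, num_experts)
experts = torch.nn.ModuleList([torch.nn.Linear(h, h) for _ in range(num_experts)])

route = torch.nn.functional.softmax(router(hidden_states), dim=2)
max_prob, max_ind = torch.max(route, dim=2)    # [b, s] each

flat = hidden_states.view(-1, h).unsqueeze(1)  # one row per token: [b*s, 1, h]
max_ind = max_ind.view(-1)                     # expert id per token, [b*s]
output_total = torch.empty_like(flat)

for expert_num, expert in enumerate(experts):
    # gather this expert's tokens, run them, scatter the results back in place
    ind = (max_ind == expert_num).nonzero().unsqueeze(2).repeat(1, 1, h)
    output_total.scatter_(0, ind, expert(torch.gather(flat, 0, ind)))

# scale by the chosen expert's router probability and restore [b, s, h]
output = (output_total * max_prob.view(-1, 1, 1)).view(b, s, h)
print(output.shape)  # torch.Size([2, 4, 8])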
#!/bin/bash
#SBATCH -A adlr -J adlr-nlp-largelm:switch_RUNVAR_expert -p luna -t 4:00:00 --nodes=1 --exclusive --mem=0 --overcommit --ntasks-per-node=8 --dependency=singleton
NAME="gpt3-357m_switch_RUNVAR_expert"
DIR=`pwd`
DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
mkdir -p $DIR/logs
CHECKPOINT_DIR="/lustre/fsw/adlr/adlr-nlp/rprenger/switch/${NAME}"
TENSORBOARD_DIR="${CHECKPOINT_DIR}/tensorboard"
mkdir -p ${TENSORBOARD_DIR}
# Get the data blend
. /lustre/fsw/adlr/adlr-nlp/data/pile-cc1-cc2-shuf/gpt3_blend.sh
BPE_DIR="/lustre/fsw/adlr/adlr-nlp/data/pile-cc1-cc2-shuf/bpe"
options=" \
--exit-duration-in-mins 230 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--global-batch-size 256 \
--train-samples 192000000 \
--lr-decay-samples 166400000 \
--lr-warmup-samples 162761 \
--lr 3.0e-4 \
--min-lr 3.0e-5 \
--lr-decay-style cosine \
--log-interval 100 \
--eval-iters 50 \
--eval-interval 2000 \
--data-path ${DATA_BLEND} \
--vocab-file ${BPE_DIR}/gpt2-vocab.json \
--merge-file ${BPE_DIR}/gpt2-merges.txt \
--save-interval 10000 \
--save ${CHECKPOINT_DIR} \
--load ${CHECKPOINT_DIR} \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.02 \
--log-params-norm \
--log-num-zeros-in-grad \
--fp16 \
--DDP-impl torch \
--tensorboard-dir ${TENSORBOARD_DIR} \
--checkpoint-activations "
run_cmd="cd $DIR && python pretrain_gpt.py ${options}"
srun -l \
--container-image "/lustre/fsw/adlr/adlr-nlp/images/pytorch+bf16_nccl_fusion.sqsh" \
--container-mounts "/lustre/fsw/adlr:/lustre/fsw/adlr,/home/rprenger/workspace:/home/rprenger/workspace" \
--ntasks-per-node 8 \
--output=$DIR/logs/%x_%j_$DATETIME.log sh -c "${run_cmd}"
set +x
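The RUNVAR token in the job name and the ${NUMEXPERTS} token patched into SwitchMLP above are left as placeholders, presumably stamped with the expert count before submission. A purely illustrative sketch of one way to do that; the template filename, output name, and the value 8 are assumptions, not part of the commit:

# Illustrative only: fill the placeholders with a concrete expert count and
# submit the job. "switch_job_template.sh" and num_experts = 8 are assumed.
from pathlib import Path
import subprocess

num_experts = 8  # example value only

job = Path("switch_job_template.sh").read_text().replace("RUNVAR", str(num_experts))
Path("switch_%d_expert.sh" % num_experts).write_text(job)

# The patched transformer.py would get the same treatment for ${NUMEXPERTS}.
subprocess.run(["sbatch", "switch_%d_expert.sh" % num_experts], check=True)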