Commit ac3db159 authored by rprenger

Got 530 Billion parameter model working!

parent 453414da
@@ -80,7 +80,10 @@ class MegatronGenerate(Resource):
     def put(self):
         args = get_args()
         sentences = request.get_json()["sentences"]
-        max_len = args.seq_length
+        if len(sentences) > 128:
+            return "Maximum number of sentences is 128", 400
+        max_len = 64  # Choosing hopefully sane default. Full sequence is slow
         if "max_len" in request.get_json():
             input_max_len = request.get_json()["max_len"]
             if input_max_len < args.seq_length:
...
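For context (not part of the commit), a hedged sketch of how a client might exercise the patched PUT handler. The host, port, and endpoint path below are placeholders; only the payload keys ("sentences", "max_len"), the 128-sentence cap, and the 400 response come from the change above.

# Hypothetical client call; the URL and endpoint path are assumptions.
import requests

payload = {
    "sentences": ["Megatron-LM is"],  # the server rejects more than 128 sentences with a 400
    "max_len": 64,                    # optional; the server now defaults to 64 if omitted
}
resp = requests.put("http://localhost:5000/generate", json=payload)
print(resp.status_code, resp.text)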
@@ -21,6 +21,7 @@ import time
 import numpy as np
 import torch
+from datetime import timedelta
 from megatron import fused_kernels
 from megatron import get_adlr_autoresume
...
@@ -183,7 +184,8 @@ def _initialize_distributed():
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
-            init_method=init_method)
+            init_method=init_method,
+            timeout=timedelta(days=7))
        # Set the tensor model-parallel, pipeline model-parallel, and
        # data-parallel communicators.
...
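As a side note, a minimal self-contained sketch (not from this commit) of why the process-group timeout matters here: a rank that sits in a collective longer than the timeout aborts, so an interactive generation server that can idle between requests needs a generous value such as the seven days used above. The gloo backend and rendezvous address are placeholders so the sketch runs on a CPU-only machine.

from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # Same idea as the patched Megatron init: a long timeout keeps idle ranks alive.
    dist.init_process_group(
        backend="gloo",                       # placeholder; Megatron uses args.distributed_backend
        init_method="tcp://127.0.0.1:29500",  # placeholder rendezvous address
        world_size=world_size,
        rank=rank,
        timeout=timedelta(days=7),
    )
    t = torch.tensor([float(rank)])
    dist.all_reduce(t)                        # would raise if a peer exceeded the timeout
    print(f"rank {rank} sees sum {t.item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)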
@@ -133,6 +133,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
     args = get_args()
     orig_seq_length = args.seq_length
     args.seq_length = tokens.shape[1]
+    args.micro_batch_size = tokens.shape[0]
     input_tensor = recv_forward()
...
#!/bin/bash
CHECKPOINT="/home/universal-lm-data.cosmos549/scratch/jcasper/gpt3-530b-megatron_tp16_pp3"
DATA_PATH="/home/universal-lm-data.cosmos549/scratch/mshoeybi/data/gpt2"
VOCAB_FILE="${DATA_PATH}/bpe/gpt2-vocab.json"
MERGE_FILE="${DATA_PATH}/bpe/gpt2-merges.txt"
RUN_CMD=(
    python -m cProfile -s cumtime tools/run_api_server.py
    --tensor-model-parallel-size 16
    --pipeline-model-parallel-size 3
    --num-layers 105
    --hidden-size 20480
    --load ${CHECKPOINT}
    --num-attention-heads 128
    --max-position-embeddings 2048
    --tokenizer-type GPT2BPETokenizer
    --fp16
    --micro-batch-size 1
    --seq-length 2048
    --out-seq-length 2048
    --temperature 1.0
    --vocab-file ${VOCAB_FILE}
    --merge-file ${MERGE_FILE}
    --top_p 0.9
    --seed 42
)

submit_job --nodes 3 --gpu 16 \
    --reservation adlr-530b --partition batch_UN_dgx2_singlenode \
    --mounts /home/universal-lm-data.cosmos549,/home/dcg-adlr-rprenger-source.cosmos352,/home/dcg-adlr-sgodil-data.cosmos233,/home/dcg-adlr-rprenger-output.cosmos349,/home/dcg-adlr-mchrzanowski-chidesign-data \
    --image gitlab-master.nvidia.com/adlr/rprenger/megatron:latest \
    --skip_ib_check --tasks_per_node 16 \
    -c "${RUN_CMD[*]}"
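For a sanity check on the commit title, a rough parameter count for the configuration in this launch script, using the usual ~12 · layers · hidden² approximation for the transformer stack plus the embedding tables. The unpadded GPT-2 vocabulary size is an assumption; the padded size depends on tokenizer and tensor-parallel settings.

# Back-of-the-envelope count for --num-layers 105 --hidden-size 20480 --seq-length 2048.
layers, hidden, seq, vocab = 105, 20480, 2048, 50257  # vocab size is an assumption

transformer = 12 * layers * hidden ** 2      # attention + MLP weights per the standard estimate
embeddings = vocab * hidden + seq * hidden   # token + position embedding tables

total = transformer + embeddings
print(f"{total / 1e9:.1f}B parameters")      # ~529.6B, i.e. the "530 Billion" in the title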