Commit d0d55509 authored by hepj987's avatar hepj987

Fix tool

parent 6bd15ea7
Pipeline #579 failed
@@ -178,16 +178,6 @@ mpirun -np 1 run-inf.sh
 --num-samples 生成样本个数
 ```
-## 应用场景
-### 算法类别
-`文本生成`
-### 热点应用行业
-`互联网`
 ## result
 16B模型训练loss:
@@ -208,6 +198,16 @@ mpirun -np 1 run-inf.sh
 ![image-20230524143830580](image-gpt-loss2.png)
+## 应用场景
+### 算法类别
+`文本生成`
+### 热点应用行业
+`互联网`
 ## 源码仓库及问题反馈
 https://developer.hpccube.com/codes/modelzoo/gpt2-pytorch/
......
@@ -5,6 +5,6 @@ modelName=gpt2_pytorch
 # 模型描述
 modelDescription=基于Pytorch训练框架的gpt2模型
 # 应用场景
-appScenario=训练,推理,train,inference,nlp,智能聊天助手
+appScenario=训练,推理,文本生成,互联网
 # 框架类型
 frameType=Pytorch,Deepspeed
@@ -7,6 +7,6 @@ np=$(($np*8))
 nodename=$(cat $hostfile |sed -n "1p")
 dist_url=`echo $nodename | awk '{print $1}'`
 which mpirun
-mpirun -np $np --allow-run-as-root --hostfile hostfile --bind-to none --mca btl_tcp_if_include $dist_url single-16B-fp16.sh
+mpirun -np $np --allow-run-as-root --hostfile $hostfile --bind-to none --mca btl_tcp_if_include $dist_url single-16B-fp16.sh $dist_url
 echo "END TIME: $(date)"
@@ -7,6 +7,6 @@ np=$(($np*8))
 nodename=$(cat $hostfile |sed -n "1p")
 dist_url=`echo $nodename | awk '{print $1}'`
 which mpirun
-mpirun -np $np --allow-run-as-root --hostfile hostfile --bind-to none --mca btl_tcp_if_include $dist_url single-16B.sh
+mpirun -np $np --allow-run-as-root --hostfile $hostfile --bind-to none --mca btl_tcp_if_include $dist_url single-16B.sh $dist_url
 echo "END TIME: $(date)"
@@ -53,7 +53,7 @@ GPT_ARGS=" \
 --max-position-embeddings $SEQ_LEN \
 --micro-batch-size $MICRO_BATCH_SIZE \
 --global-batch-size $GLOBAL_BATCH_SIZE \
---train_iters 7000 \
+--train-iters 7000 \
 --loss-scale 12 \
 --vocab-file gpt2-vocab.json \
 --merge-file gpt2-merges.txt \
......
@@ -53,7 +53,7 @@ GPT_ARGS=" \
 --max-position-embeddings $SEQ_LEN \
 --micro-batch-size $MICRO_BATCH_SIZE \
 --global-batch-size $GLOBAL_BATCH_SIZE \
---train_iters 7000 \
+--train-iters 7000 \
 --loss-scale 12 \
 --vocab-file gpt2-vocab.json \
 --merge-file gpt2-merges.txt \
......
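Both training scripts rename `--train_iters` to `--train-iters`. Megatron-style launchers hand `GPT_ARGS` to a Python entry point that parses its flags with argparse, which registers the hyphenated spelling (exposed afterwards as `args.train_iters`) and rejects the underscore form on the command line. A minimal, self-contained sketch of that behaviour (not the repository's actual parser):

```python
# Minimal argparse sketch (not Megatron's real argument parser) showing why the
# hyphenated flag is required on the command line.
import argparse

parser = argparse.ArgumentParser()
# Registered as --train-iters; argparse stores the value under args.train_iters.
parser.add_argument("--train-iters", type=int, default=None)

args = parser.parse_args(["--train-iters", "7000"])
print(args.train_iters)  # 7000

# parser.parse_args(["--train_iters", "7000"]) would exit with
# "error: unrecognized arguments: --train_iters 7000".
```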
@@ -4,8 +4,10 @@ import argparse
 import os
 import torch
 from collections import OrderedDict
-from .deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint
+from deepspeed.checkpoint.deepspeed_checkpoint import (
+    ARGS_KEY,
+    DeepSpeedCheckpoint,
+)
 MODEL_KEY = 'model'
 ARGS_KEY = 'args'
 LANGUGAGE_MODEL_KEY = 'language_model'
......
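The conversion tool now imports `ARGS_KEY` and `DeepSpeedCheckpoint` from the installed `deepspeed` package instead of a sibling `.deepspeed_checkpoint` module; a relative import like the old one only resolves when the file runs as part of a package, so it presumably broke when the tool was run as a standalone script. A minimal sketch of the corrected import in use; the checkpoint directory and parallel degrees below are hypothetical placeholders, and the constructor/attribute names are assumed to match the installed `deepspeed.checkpoint` package:

```python
# Minimal sketch exercising the corrected import. The path and the tensor/
# pipeline parallel degrees are hypothetical; constructor and attribute names
# are assumed to match the installed deepspeed.checkpoint package.
from deepspeed.checkpoint.deepspeed_checkpoint import (
    ARGS_KEY,
    DeepSpeedCheckpoint,
)

checkpoint_dir = "checkpoints/gpt2-16b/global_step7000"  # hypothetical path
target_tp, target_pp = 1, 1                              # hypothetical degrees

ds_checkpoint = DeepSpeedCheckpoint(checkpoint_dir, target_tp, target_pp)
print(ARGS_KEY)  # key under which saved training arguments are stored ('args')
print(ds_checkpoint.tp_degree, ds_checkpoint.pp_degree)
```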