"docs/source/en/community.mdx" did not exist on "0311ba21534e3f21ac9b9327009e669e34c9b367"
split_bin.py 698 Bytes
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch,transformers
import sys, os
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))

from transformers import AutoModelForCausalLM,AutoConfig, LlamaTokenizer,AutoTokenizer
from yuan_moe_hf_model import YuanForCausalLM

import argparse

def build_parser():
    """Build the command-line parser for the checkpoint re-sharding script.

    Returns:
        argparse.ArgumentParser accepting --input-path, --output-path and
        --max-shard-size.
    """
    parser = argparse.ArgumentParser(
        description='Load a YuanForCausalLM checkpoint and re-save it in shards.')
    # Both paths are mandatory: without them from_pretrained/save_pretrained
    # would receive None and fail only after argument parsing succeeded.
    parser.add_argument('--input-path', type=str, required=True,
                        help='Path to the input checkpoint (passed to from_pretrained)')
    parser.add_argument('--output-path', type=str, required=True,
                        help='Path to the output directory (passed to save_pretrained)')
    # Previously hard-coded; the default keeps the original behavior.
    parser.add_argument('--max-shard-size', type=str, default='10GB',
                        help="Maximum size of each saved shard, e.g. '10GB' (default: 10GB)")
    return parser


def main(argv=None):
    """Load the model from --input-path and save it to --output-path in shards.

    Args:
        argv: optional list of command-line arguments; when None, argparse
            reads sys.argv[1:].
    """
    args = build_parser().parse_args(argv)
    # Load weights in bfloat16; device_map='auto' delegates weight placement
    # to transformers.
    model = YuanForCausalLM.from_pretrained(
        args.input_path,
        device_map='auto',
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model.save_pretrained(args.output_path, max_shard_size=args.max_shard_size)


if __name__ == '__main__':
    main()