"torchvision/models/vscode:/vscode.git/clone" did not exist on "1b7c0f54e2913e159394a19ac5e50daa69c142c7"
Commit e5ca7e62 authored by hepj987's avatar hepj987
Browse files

初始化仓库

parents
Pipeline #437 failed with stages
in 0 seconds
#!/bin/bash
# HIP_LAUNCH_BLOCKING=1 forces synchronous (blocking) kernel launches, the HIP
# analogue of CUDA_LAUNCH_BLOCKING, which makes device-side errors easier to localize.
export HIP_LAUNCH_BLOCKING=1
mpirun --allow-run-as-root -np 4 single_pre1_4.sh
#!/bin/bash
mpirun --allow-run-as-root -np 1 single_pre1_1_fp16.sh
#!/bin/bash
mpirun --allow-run-as-root -np 1 single_pre2_1.sh
#!/bin/bash
export HIP_LAUNCH_BLOCKING=1
mpirun --allow-run-as-root -np 4 single_pre2_4.sh
#!/bin/bash
export HIP_LAUNCH_BLOCKING=1
mpirun --allow-run-as-root -np 4 single_pre2_4_fp16.sh
#!/bin/bash
mpirun --allow-run-as-root -np 1 single_pre2_1_fp16.sh
#!/bin/bash
mpirun --allow-run-as-root -np 1 single_squad.sh
#!/bin/bash
#export LD_LIBRARY_PATH=/public/home/hepj/job_env/apps/dtk-21.10.1/lib
mpirun --allow-run-as-root -np 4 single_squad4.sh
#!/bin/bash
#export LD_LIBRARY_PATH=/public/home/hepj/job_env/apps/dtk-21.10.1/lib
mpirun --allow-run-as-root -np 4 single_squad4_fp16.sh
#!/bin/bash
mpirun --allow-run-as-root -np 1 single_squad_fp16.sh
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
set -euo pipefail
print_usage() {
cat << EOF
${0} [options] [--] COMMAND [ARG...]
Control binding policy for each task. Assumes one rank will be launched for each GPU.
Options:
--cpu=MODE
* exclusive -- bind each rank to an exclusive set of cores near its GPU
* exclusive,nosmt -- bind each rank to an exclusive set of cores near its GPU, without hyperthreading
* node -- bind each rank to all cores in the NUMA node nearest its GPU [default]
* *.sh -- bind each rank using the bash associative array bind_cpu_cores or bind_cpu_nodes from a file
* off -- don't bind
--mem=MODE
* node -- bind each rank to the nearest NUMA node [default]
* *.sh -- bind each rank using the bash associative array bind_mem from a file
* off -- don't bind
--ib=MODE
* single -- bind each rank to a single IB device near its GPU
* off -- don't bind [default]
--cluster=CLUSTER
Select which cluster is being used. May be required if system params cannot be detected.
EOF
}
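# Example (hypothetical invocation; the script name and rank count are
# assumptions): launch one rank per GPU, binding each rank to the cores and
# memory of the NUMA node nearest its GPU:
#   mpirun -np 8 ./bind.sh --cpu=node --mem=node -- python train.py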
################################################################################
# Argument parsing
################################################################################
cpu_mode='node'
mem_mode='node'
ib_mode='off'
cluster=''
while [ $# -gt 0 ]; do
case "$1" in
-h|--help) print_usage ; exit 0 ;;
--cpu=*) cpu_mode="${1/*=/}"; shift ;;
--cpu) cpu_mode="$2"; shift 2 ;;
--mem=*) mem_mode="${1/*=/}"; shift ;;
--mem) mem_mode="$2"; shift 2 ;;
--ib=*) ib_mode="${1/*=/}"; shift ;;
--ib) ib_mode="$2"; shift 2 ;;
--cluster=*) cluster="${1/*=/}"; shift ;;
--cluster) cluster="$2"; shift 2 ;;
--) shift; break ;;
*) break ;;
esac
done
if [ $# -lt 1 ]; then
echo 'ERROR: no command given' >&2
print_usage
exit 1
fi
################################################################################
# Get system params
################################################################################
# LOCAL_RANK is set with an enroot hook for Pytorch containers
# SLURM_LOCALID is set by Slurm
# OMPI_COMM_WORLD_LOCAL_RANK is set by mpirun
readonly local_rank="${LOCAL_RANK:=${SLURM_LOCALID:=${OMPI_COMM_WORLD_LOCAL_RANK:-}}}"
if [ -z "${local_rank}" ]; then
echo 'ERROR: cannot read LOCAL_RANK from env' >&2
exit 1
fi
num_gpus=$(nvidia-smi -i 0 --query-gpu=count --format=csv,noheader,nounits)
if [ "${local_rank}" -ge "${num_gpus}" ]; then
echo "ERROR: local rank is ${local_rank}, but there are only ${num_gpus} gpus available" >&2
exit 1
fi
get_lscpu_value() {
awk -F: "(\$1 == \"${1}\"){gsub(/ /, \"\", \$2); print \$2; found=1} END{exit found!=1}"
}
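# Example: `get_lscpu_value 'Socket(s)' <<< "${lscpu_out}"` prints the value
# after the colon with spaces stripped, and exits non-zero if the key is absent.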
lscpu_out=$(lscpu)
num_sockets=$(get_lscpu_value 'Socket(s)' <<< "${lscpu_out}")
num_nodes=$(get_lscpu_value 'NUMA node(s)' <<< "${lscpu_out}")
cores_per_socket=$(get_lscpu_value 'Core(s) per socket' <<< "${lscpu_out}")
echo "num_sockets = ${num_sockets} num_nodes=${num_nodes} cores_per_socket=${cores_per_socket}"
readonly cores_per_node=$(( (num_sockets * cores_per_socket) / num_nodes ))
if [ ${num_gpus} -gt 1 ]; then
readonly gpus_per_node=$(( num_gpus / num_nodes ))
else
readonly gpus_per_node=1
fi
readonly cores_per_gpu=$(( cores_per_node / gpus_per_node ))
readonly local_node=$(( local_rank / gpus_per_node ))
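# Worked example (assumed topology): 2 sockets x 20 cores, 2 NUMA nodes and
# 8 GPUs give cores_per_node=20, gpus_per_node=4 and cores_per_gpu=5, so
# ranks 0-3 land on NUMA node 0 and ranks 4-7 on node 1.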
declare -a ibdevs=()
case "${cluster}" in
circe)
# Need to specialize for circe because IB detection is hard
ibdevs=(mlx5_1 mlx5_2 mlx5_3 mlx5_4 mlx5_7 mlx5_8 mlx5_9 mlx5_10)
;;
selene)
# Need to specialize for selene because IB detection is hard
ibdevs=(mlx5_0 mlx5_1 mlx5_2 mlx5_3 mlx5_6 mlx5_7 mlx5_8 mlx5_9)
;;
'')
if ibstat_out="$(ibstat -l 2>/dev/null | sort -V)" ; then
mapfile -t ibdevs <<< "${ibstat_out}"
fi
;;
*)
echo "ERROR: Unknown cluster '${cluster}'" >&2
exit 1
;;
esac
readonly num_ibdevs="${#ibdevs[@]}"
################################################################################
# Setup for exec
################################################################################
declare -a numactl_args=()
case "${cpu_mode}" in
exclusive)
numactl_args+=( "$(printf -- "--physcpubind=%u-%u,%u-%u" \
$(( local_rank * cores_per_gpu )) \
$(( (local_rank + 1) * cores_per_gpu - 1 )) \
$(( local_rank * cores_per_gpu + (cores_per_gpu * gpus_per_node * num_nodes) )) \
$(( (local_rank + 1) * cores_per_gpu + (cores_per_gpu * gpus_per_node * num_nodes) - 1 )) \
)" )
;;
exclusive,nosmt)
numactl_args+=( "$(printf -- "--physcpubind=%u-%u" \
$(( local_rank * cores_per_gpu )) \
$(( (local_rank + 1) * cores_per_gpu - 1 )) \
)" )
;;
node)
numactl_args+=( "--cpunodebind=${local_node}" )
;;
*.sh)
source "${cpu_mode}"
if [ -n "${bind_cpu_cores:-}" ]; then
numactl_args+=( "--physcpubind=${bind_cpu_cores[${local_rank}]}" )
elif [ -n "${bind_cpu_nodes:-}" ]; then
numactl_args+=( "--cpunodebind=${bind_cpu_nodes[${local_rank}]}" )
else
echo "ERROR: invalid CPU affinity file ${cpu_mode}." >&2
exit 1
fi
;;
off|'')
;;
*)
echo "ERROR: invalid cpu mode '${cpu_mode}'" 2>&1
print_usage
exit 1
;;
esac
case "${mem_mode}" in
node)
numactl_args+=( "--membind=${local_node}" )
;;
*.sh)
source "${mem_mode}"
if [ -z "${bind_mem:-}" ]; then
echo "ERROR: invalid memory affinity file ${mem_mode}." >&2
exit 1
fi
numactl_args+=( "--membind=${bind_mem[${local_rank}]}" )
;;
off|'')
;;
*)
echo "ERROR: invalid mem mode '${mem_mode}'" 2>&1
print_usage
exit 1
;;
esac
case "${ib_mode}" in
single)
if [ "${num_ibdevs}" -eq 0 ]; then
echo "WARNING: used '$0 --ib=single', but there are 0 IB devices available; skipping IB binding." 2>&1
else
readonly ibdev="${ibdevs[$(( local_rank * num_ibdevs / num_gpus ))]}"
export OMPI_MCA_btl_openib_if_include="${OMPI_MCA_btl_openib_if_include-$ibdev}"
fi
;;
off|'')
;;
*)
echo "ERROR: invalid ib mode '${ib_mode}'" 2>&1
print_usage
exit 1
;;
esac
################################################################################
# Exec
################################################################################
if [ "${#numactl_args[@]}" -gt 0 ] ; then
set -x
exec numactl "${numactl_args[@]}" -- "${@}"
else
exec "${@}"
fi
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import subprocess
import os
import socket
from argparse import ArgumentParser, REMAINDER
import torch
def parse_args():
"""
Helper function parsing the command line options
@retval ArgumentParser
"""
parser = ArgumentParser(description="PyTorch distributed training launch "
"helper utilty that will spawn up "
"multiple distributed processes")
# Optional arguments for the launch helper
parser.add_argument("--nnodes", type=int, default=1,
help="The number of nodes to use for distributed "
"training")
parser.add_argument("--node_rank", type=int, default=0,
help="The rank of the node for multi-node distributed "
"training")
parser.add_argument("--nproc_per_node", type=int, default=1,
help="The number of processes to launch on each node, "
"for GPU training, this is recommended to be set "
"to the number of GPUs in your system so that "
"each process can be bound to a single GPU.")
parser.add_argument("--master_addr", default="127.0.0.1", type=str,
help="Master node (rank 0)'s address, should be either "
"the IP address or the hostname of node 0, for "
"single node multi-proc training, the "
"--master_addr can simply be 127.0.0.1")
parser.add_argument("--master_port", default=29500, type=int,
help="Master node (rank 0)'s free port that needs to "
"be used for communciation during distributed "
"training")
parser.add_argument('--no_hyperthreads', action='store_true',
help='Flag to disable binding to hyperthreads')
parser.add_argument('--no_membind', action='store_true',
help='Flag to disable memory binding')
# non-optional arguments for binding
parser.add_argument("--nsockets_per_node", type=int, required=True,
help="Number of CPU sockets on a node")
parser.add_argument("--ncores_per_socket", type=int, required=True,
help="Number of CPU cores per socket")
# positional
parser.add_argument("training_script", type=str,
help="The full path to the single GPU training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script")
# rest from the training program
parser.add_argument('training_script_args', nargs=REMAINDER)
return parser.parse_args()
def main():
args = parse_args()
# variables for numactl binding
NSOCKETS = args.nsockets_per_node
NGPUS_PER_SOCKET = (args.nproc_per_node // args.nsockets_per_node) + (1 if (args.nproc_per_node % args.nsockets_per_node) else 0)
NCORES_PER_GPU = args.ncores_per_socket // NGPUS_PER_SOCKET
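# Worked example (assumed topology): with --nproc_per_node=8 and
# --nsockets_per_node=2, NGPUS_PER_SOCKET=4; with --ncores_per_socket=20,
# NCORES_PER_GPU=5, so rank 0 is bound to physical cores 0-4 and (unless
# --no_hyperthreads is given) their SMT siblings 40-44.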
# world size in terms of number of processes
dist_world_size = args.nproc_per_node * args.nnodes
# set PyTorch distributed related environmental variables
current_env = os.environ.copy()
current_env["MASTER_ADDR"] = args.master_addr
current_env["MASTER_PORT"] = str(args.master_port)
current_env["WORLD_SIZE"] = str(dist_world_size)
processes = []
for local_rank in range(0, args.nproc_per_node):
# each process's rank
dist_rank = args.nproc_per_node * args.node_rank + local_rank
current_env["RANK"] = str(dist_rank)
# form numactl binding command
cpu_ranges = [local_rank * NCORES_PER_GPU,
(local_rank + 1) * NCORES_PER_GPU - 1,
local_rank * NCORES_PER_GPU + (NCORES_PER_GPU * NGPUS_PER_SOCKET * NSOCKETS),
(local_rank + 1) * NCORES_PER_GPU + (NCORES_PER_GPU * NGPUS_PER_SOCKET * NSOCKETS) - 1]
numactlargs = []
if args.no_hyperthreads:
numactlargs += [ "--physcpubind={}-{}".format(*cpu_ranges[0:2]) ]
else:
numactlargs += [ "--physcpubind={}-{},{}-{}".format(*cpu_ranges) ]
if not args.no_membind:
memnode = local_rank // NGPUS_PER_SOCKET
numactlargs += [ "--membind={}".format(memnode) ]
# spawn the processes
cmd = [ "/usr/bin/numactl" ] \
+ numactlargs \
+ [ sys.executable,
"-u",
args.training_script,
"--local_rank={}".format(local_rank)
] \
+ args.training_script_args
process = subprocess.Popen(cmd, env=current_env)
processes.append(process)
for process in processes:
process.wait()
if __name__ == "__main__":
main()
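# Example invocation (hypothetical file name, paths and topology), launching
# 8 bound processes on a 2-socket, 20-cores-per-socket node:
#   python launch.py --nproc_per_node=8 --nsockets_per_node=2 \
#       --ncores_per_socket=20 train.py --your-training-args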
# Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
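# Note on naming (our reading of the pattern below): job suffixes encode
# <nodes>x<GPUs per node>x<per-GPU micro-batch>x<gradient accumulation steps>,
# where the micro-batch equals BATCHSIZE / GRADIENT_STEPS. The anchors
# *BERT_ON_CLUSTER, *DGX1, *DGX2 and their _VARS are defined elsewhere in the
# CI configuration.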
#1 DGX1 phase1
bert--DGX1:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "1"
BATCHSIZE: "8192"
LR: "6e-3"
GRADIENT_STEPS: "512"
PHASE: "1"
#4 DGX1 phase1
bert--DGX1_4x8x16x128:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "4"
BATCHSIZE: "2048"
LR: "6e-3"
GRADIENT_STEPS: "128"
PHASE: "1"
#16 DGX1 phase1
bert--DGX1_16x8x16x32:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "16"
BATCHSIZE: "512"
LR: "6e-3"
GRADIENT_STEPS: "32"
PHASE: "1"
#1 DGX2 phase1
bert--DGX2:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "1"
BATCHSIZE: "4096"
LR: "6e-3"
GRADIENT_STEPS: "64"
PHASE: "1"
#4 DGX2 phase1
bert--DGX2_4x16x64x16:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "4"
BATCHSIZE: "1024"
LR: "6e-3"
GRADIENT_STEPS: "16"
PHASE: "1"
#16 DGX2 phase1
bert--DGX2_16x16x64x4:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "16"
BATCHSIZE: "256"
LR: "6e-3"
GRADIENT_STEPS: "4"
PHASE: "1"
#64 DGX2 phase1
bert--DGX2_64x16x64x1:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "64"
BATCHSIZE: "64"
LR: "6e-3"
GRADIENT_STEPS: "1"
PHASE: "1"
#1 DGX1 phase2
bert--DGX1_1x8x4x1024:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "1"
BATCHSIZE: "4096"
LR: "4e-3"
GRADIENT_STEPS: "1024"
PHASE: "2"
#4 DGX1 phase2
bert--DGX1_4x8x4x256:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "4"
BATCHSIZE: "1024"
LR: "4e-3"
GRADIENT_STEPS: "256"
PHASE: "2"
#16 DGX1 phase2
bert--DGX1_16x8x4x64:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "16"
BATCHSIZE: "256"
LR: "4e-3"
GRADIENT_STEPS: "64"
PHASE: "2"
#1 DGX2 phase2
bert--DGX2_1x16x8x256:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "1"
BATCHSIZE: "2048"
LR: "4e-3"
GRADIENT_STEPS: "256"
PHASE: "2"
#4 DGX2 phase2
bert--DGX2_4x16x8x64:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "4"
BATCHSIZE: "512"
LR: "4e-3"
GRADIENT_STEPS: "64"
PHASE: "2"
#16 DGX2 phase2
bert--DGX2_16x16x8x16:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "16"
BATCHSIZE: "128"
LR: "4e-3"
GRADIENT_STEPS: "16"
PHASE: "2"
#64 DGX2 phase2
bert--DGX2_64x16x8x4:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "64"
BATCHSIZE: "32"
LR: "4e-3"
GRADIENT_STEPS: "4"
PHASE: "2"
# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import logging
import os
import random
from io import open
import h5py
import numpy as np
from tqdm import tqdm, trange
from tokenization import BertTokenizer
import tokenization
import collections
class TrainingInstance(object):
"""A single training instance (sentence pair)."""
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
is_random_next):
self.tokens = tokens
self.segment_ids = segment_ids
self.is_random_next = is_random_next
self.masked_lm_positions = masked_lm_positions
self.masked_lm_labels = masked_lm_labels
def __str__(self):
s = ""
s += "tokens: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.tokens]))
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
s += "is_random_next: %s\n" % self.is_random_next
s += "masked_lm_positions: %s\n" % (" ".join(
[str(x) for x in self.masked_lm_positions]))
s += "masked_lm_labels: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
s += "\n"
return s
def __repr__(self):
return self.__str__()
def write_instance_to_example_file(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_file):
"""Create TF example files from `TrainingInstance`s."""
total_written = 0
features = collections.OrderedDict()
num_instances = len(instances)
features["input_ids"] = np.zeros([num_instances, max_seq_length], dtype="int32")
features["input_mask"] = np.zeros([num_instances, max_seq_length], dtype="int32")
features["segment_ids"] = np.zeros([num_instances, max_seq_length], dtype="int32")
features["masked_lm_positions"] = np.zeros([num_instances, max_predictions_per_seq], dtype="int32")
features["masked_lm_ids"] = np.zeros([num_instances, max_predictions_per_seq], dtype="int32")
features["next_sentence_labels"] = np.zeros(num_instances, dtype="int32")
for inst_index, instance in enumerate(tqdm(instances)):
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
input_mask = [1] * len(input_ids)
segment_ids = list(instance.segment_ids)
assert len(input_ids) <= max_seq_length
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
masked_lm_positions = list(instance.masked_lm_positions)
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
masked_lm_weights = [1.0] * len(masked_lm_ids)
while len(masked_lm_positions) < max_predictions_per_seq:
masked_lm_positions.append(0)
masked_lm_ids.append(0)
masked_lm_weights.append(0.0)
next_sentence_label = 1 if instance.is_random_next else 0
features["input_ids"][inst_index] = input_ids
features["input_mask"][inst_index] = input_mask
features["segment_ids"][inst_index] = segment_ids
features["masked_lm_positions"][inst_index] = masked_lm_positions
features["masked_lm_ids"][inst_index] = masked_lm_ids
features["next_sentence_labels"][inst_index] = next_sentence_label
total_written += 1
# if inst_index < 20:
# tf.logging.info("*** Example ***")
# tf.logging.info("tokens: %s" % " ".join(
# [tokenization.printable_text(x) for x in instance.tokens]))
# for feature_name in features.keys():
# feature = features[feature_name]
# values = []
# if feature.int64_list.value:
# values = feature.int64_list.value
# elif feature.float_list.value:
# values = feature.float_list.value
# tf.logging.info(
# "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
print("saving data")
f = h5py.File(output_file, 'w')
f.create_dataset("input_ids", data=features["input_ids"], dtype='i4', compression='gzip')
f.create_dataset("input_mask", data=features["input_mask"], dtype='i1', compression='gzip')
f.create_dataset("segment_ids", data=features["segment_ids"], dtype='i1', compression='gzip')
f.create_dataset("masked_lm_positions", data=features["masked_lm_positions"], dtype='i4', compression='gzip')
f.create_dataset("masked_lm_ids", data=features["masked_lm_ids"], dtype='i4', compression='gzip')
f.create_dataset("next_sentence_labels", data=features["next_sentence_labels"], dtype='i1', compression='gzip')
f.flush()
f.close()
def create_training_instances(input_files, tokenizer, max_seq_length,
dupe_factor, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, rng):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
# Input file format:
# (1) One sentence per line. These should ideally be actual sentences, not
# entire paragraphs or arbitrary spans of text. (Because we use the
# sentence boundaries for the "next sentence prediction" task).
# (2) Blank lines between documents. Document boundaries are needed so
# that the "next sentence prediction" task doesn't span between documents.
for input_file in input_files:
print("creating instance from {}".format(input_file))
with open(input_file, "r") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
break
line = line.strip()
# Empty lines are used as document delimiters
if not line:
all_documents.append([])
tokens = tokenizer.tokenize(line)
if tokens:
all_documents[-1].append(tokens)
# Remove empty documents
all_documents = [x for x in all_documents if x]
rng.shuffle(all_documents)
vocab_words = list(tokenizer.vocab.keys())
instances = []
for _ in range(dupe_factor):
for document_index in range(len(all_documents)):
instances.extend(
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
rng.shuffle(instances)
return instances
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
# Account for [CLS], [SEP], [SEP]
max_num_tokens = max_seq_length - 3
# We *usually* want to fill up the entire sequence since we are padding
# to `max_seq_length` anyways, so short sequences are generally wasted
# computation. However, we *sometimes*
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
# sequences to minimize the mismatch between pre-training and fine-tuning.
# The `target_seq_length` is just a rough target however, whereas
# `max_seq_length` is a hard limit.
target_seq_length = max_num_tokens
if rng.random() < short_seq_prob:
target_seq_length = rng.randint(2, max_num_tokens)
# We DON'T just concatenate all of the tokens from a document into a long
# sequence and choose an arbitrary split point because this would make the
# next sentence prediction task too easy. Instead, we split the input into
# segments "A" and "B" based on the actual "sentences" provided by the user
# input.
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end` is how many segments from `current_chunk` go into the `A`
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = rng.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
# Random next
is_random_next = False
if len(current_chunk) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
# This should rarely go for more than one iteration for large
# corpora. However, just to be careful, we try to make sure that
# the random document is not the same as the document
# we're processing.
for _ in range(10):
random_document_index = rng.randint(0, len(all_documents) - 1)
if random_document_index != document_index:
break
# If the picked random document is the same as the current document
if random_document_index == document_index:
is_random_next = False
random_document = all_documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# We didn't actually use these segments so we "put them back" so
# they don't go to waste.
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
# Actual next
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
is_random_next=is_random_next,
masked_lm_positions=masked_lm_positions,
masked_lm_labels=masked_lm_labels)
instances.append(instance)
current_chunk = []
current_length = 0
i += 1
return instances
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
cand_indexes.append(i)
rng.shuffle(cand_indexes)
output_tokens = list(tokens)
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
if index in covered_indexes:
continue
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_positions = []
masked_lm_labels = []
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels)
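# Example: with masked_lm_prob=0.15 and a 128-token sequence, about 19 tokens
# are chosen (capped by max_predictions_per_seq); each becomes [MASK] with
# probability 0.8, keeps its original token with probability 0.1, and is
# replaced by a random vocabulary word with probability 0.1.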
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
"""Truncates a pair of sequences to a maximum sequence length."""
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
# We want to sometimes truncate from the front and sometimes from the
# back to add more randomness and avoid biases.
if rng.random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
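# Example invocation (hypothetical file name and paths), producing 128-token
# pretraining shards in HDF5:
#   python create_pretraining_data.py \
#       --input_file=/data/wiki_and_books/ \
#       --output_file=/data/hdf5/pretrain_128.hdf5 \
#       --vocab_file=/data/vocab.txt \
#       --max_seq_length=128 --max_predictions_per_seq=20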
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--vocab_file",
default=None,
type=str,
required=True,
help="The vocabulary the BERT model will train on.")
parser.add_argument("--input_file",
default=None,
type=str,
required=True,
help="The input train corpus. can be directory with .txt files or a path to a single file")
parser.add_argument("--output_file",
default=None,
type=str,
required=True,
help="The output file where the model checkpoints will be written.")
## Other parameters
# str
parser.add_argument("--bert_model", default="bert-large-uncased", type=str, required=False,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
#int
parser.add_argument("--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument("--dupe_factor",
default=10,
type=int,
help="Number of times to duplicate the input data (with different masks).")
parser.add_argument("--max_predictions_per_seq",
default=20,
type=int,
help="Maximum sequence length.")
# floats
parser.add_argument("--masked_lm_prob",
default=0.15,
type=float,
help="Masked LM probability.")
parser.add_argument("--short_seq_prob",
default=0.1,
type=float,
help="Probability to create a sequence shorter than maximum sequence length")
parser.add_argument("--do_lower_case",
action='store_true',
default=True,
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument('--random_seed',
type=int,
default=12345,
help="random seed for initialization")
args = parser.parse_args()
tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512)
input_files = []
if os.path.isfile(args.input_file):
input_files.append(args.input_file)
elif os.path.isdir(args.input_file):
input_files = [os.path.join(args.input_file, f) for f in os.listdir(args.input_file) if (os.path.isfile(os.path.join(args.input_file, f)) and f.endswith('.txt') )]
else:
raise ValueError("{} is not a valid path".format(args.input_file))
rng = random.Random(args.random_seed)
instances = create_training_instances(
input_files, tokenizer, args.max_seq_length, args.dupe_factor,
args.short_seq_prob, args.masked_lm_prob, args.max_predictions_per_seq,
rng)
output_file = args.output_file
write_instance_to_example_file(instances, tokenizer, args.max_seq_length,
args.max_predictions_per_seq, output_file)
if __name__ == "__main__":
main()
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
class BooksDownloader:
def __init__(self, save_path):
self.save_path = save_path
def download(self):
bookscorpus_download_command = 'python3 /workspace/bookcorpus/download_files.py --list /workspace/bookcorpus/url_list.jsonl --out'
bookscorpus_download_command += ' ' + self.save_path + '/bookscorpus'
bookscorpus_download_command += ' --trash-bad-count'
bookscorpus_download_process = subprocess.run(bookscorpus_download_command, shell=True, check=True)
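# Example (hypothetical path): BooksDownloader('/workspace/data').download()
# downloads BookCorpus into /workspace/data/bookscorpus via the bundled
# download_files.py helper.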
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
class BookscorpusTextFormatting:
def __init__(self, books_path, output_filename, recursive = False):
self.books_path = books_path
self.recursive = recursive
self.output_filename = output_filename
# This puts one book per line
def merge(self):
with open(self.output_filename, mode='w', newline='\n') as ofile:
for filename in glob.glob(self.books_path + '/**/*.txt', recursive=self.recursive):
with open(filename, mode='r', encoding='utf-8-sig', newline='\n') as file:
for line in file:
if line.strip() != '':
ofile.write(line.strip() + ' ')
ofile.write("\n\n")
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from GooglePretrainedWeightDownloader import GooglePretrainedWeightDownloader
from NVIDIAPretrainedWeightDownloader import NVIDIAPretrainedWeightDownloader
from WikiDownloader import WikiDownloader
from BooksDownloader import BooksDownloader
from GLUEDownloader import GLUEDownloader
from SquadDownloader import SquadDownloader
class Downloader:
def __init__(self, dataset_name, save_path):
self.dataset_name = dataset_name
self.save_path = save_path
def download(self):
if self.dataset_name == 'bookscorpus':
self.download_bookscorpus()
elif self.dataset_name == 'wikicorpus_en':
self.download_wikicorpus('en')
elif self.dataset_name == 'wikicorpus_zh':
self.download_wikicorpus('zh')
elif self.dataset_name == 'google_pretrained_weights':
self.download_google_pretrained_weights()
elif self.dataset_name == 'nvidia_pretrained_weights':
self.download_nvidia_pretrained_weights()
elif self.dataset_name in {'mrpc', 'sst-2'}:
self.download_glue(self.dataset_name)
elif self.dataset_name == 'squad':
self.download_squad()
elif self.dataset_name == 'all':
self.download_bookscorpus()
self.download_wikicorpus('en')
self.download_wikicorpus('zh')
self.download_google_pretrained_weights()
self.download_nvidia_pretrained_weights()
self.download_glue('mrpc')
self.download_glue('sst-2')
self.download_squad()
else:
print(self.dataset_name)
assert False, 'Unknown dataset_name provided to downloader'
def download_bookscorpus(self):
downloader = BooksDownloader(self.save_path)
downloader.download()
def download_wikicorpus(self, language):
downloader = WikiDownloader(language, self.save_path)
downloader.download()
def download_google_pretrained_weights(self):
downloader = GooglePretrainedWeightDownloader(self.save_path)
downloader.download()
def download_nvidia_pretrained_weights(self):
downloader = NVIDIAPretrainedWeightDownloader(self.save_path)
downloader.download()
def download_glue(self, task_name):
downloader = GLUEDownloader(self.save_path)
downloader.download(task_name)
def download_squad(self):
downloader = SquadDownloader(self.save_path)
downloader.download()
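# Example (hypothetical save path): fetch a single dataset,
#   Downloader('squad', '/workspace/data').download()
# or everything at once,
#   Downloader('all', '/workspace/data').download()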
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import wget
from pathlib import Path
def mkdir(path):
Path(path).mkdir(parents=True, exist_ok=True)
class GLUEDownloader:
def __init__(self, save_path):
self.save_path = save_path + '/glue'
def download(self, task_name):
mkdir(self.save_path)
if task_name in {'mrpc', 'mnli'}:
task_name = task_name.upper()
elif task_name == 'cola':
task_name = 'CoLA'
else: # SST-2
assert task_name == 'sst-2'
task_name = 'SST'
wget.download(
'https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py',
out=self.save_path,
)
sys.path.append(self.save_path)
import download_glue_data
download_glue_data.main(
['--data_dir', self.save_path, '--tasks', task_name])
sys.path.pop()
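# Example (hypothetical save path):
#   GLUEDownloader('/workspace/data').download('mrpc')
# fetches MRPC into /workspace/data/glue via the downloaded download_glue_data.py.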
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
import urllib.request
import zipfile
class GooglePretrainedWeightDownloader:
def __init__(self, save_path):
self.save_path = save_path + '/google_pretrained_weights'
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
# Download urls
self.model_urls = {
'bert_base_uncased': ('https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip', 'uncased_L-12_H-768_A-12.zip'),
'bert_large_uncased': ('https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip', 'uncased_L-24_H-1024_A-16.zip'),
'bert_base_cased': ('https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip', 'cased_L-12_H-768_A-12.zip'),
'bert_large_cased': ('https://storage.googleapis.com/bert_models/2018_10_18/cased_L-24_H-1024_A-16.zip', 'cased_L-24_H-1024_A-16.zip'),
'bert_base_multilingual_cased': ('https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip', 'multi_cased_L-12_H-768_A-12.zip'),
'bert_large_multilingual_uncased': ('https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip', 'multilingual_L-12_H-768_A-12.zip'),
'bert_base_chinese': ('https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip', 'chinese_L-12_H-768_A-12.zip')
}
# SHA256sum verification for file download integrity (and checking for changes from the download source over time)
self.bert_base_uncased_sha = {
'bert_config.json': '7b4e5f53efbd058c67cda0aacfafb340113ea1b5797d9ce6ee411704ba21fcbc',
'bert_model.ckpt.data-00000-of-00001': '58580dc5e0bf0ae0d2efd51d0e8272b2f808857f0a43a88aaf7549da6d7a8a84',
'bert_model.ckpt.index': '04c1323086e2f1c5b7c0759d8d3e484afbb0ab45f51793daab9f647113a0117b',
'bert_model.ckpt.meta': 'dd5682170a10c3ea0280c2e9b9a45fee894eb62da649bbdea37b38b0ded5f60e',
'vocab.txt': '07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3',
}
self.bert_large_uncased_sha = {
'bert_config.json': 'bfa42236d269e2aeb3a6d30412a33d15dbe8ea597e2b01dc9518c63cc6efafcb',
'bert_model.ckpt.data-00000-of-00001': 'bc6b3363e3be458c99ecf64b7f472d2b7c67534fd8f564c0556a678f90f4eea1',
'bert_model.ckpt.index': '68b52f2205ffc64dc627d1120cf399c1ef1cbc35ea5021d1afc889ffe2ce2093',
'bert_model.ckpt.meta': '6fcce8ff7628f229a885a593625e3d5ff9687542d5ef128d9beb1b0c05edc4a1',
'vocab.txt': '07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3',
}
self.bert_base_cased_sha = {
'bert_config.json': 'f11dfb757bea16339a33e1bf327b0aade6e57fd9c29dc6b84f7ddb20682f48bc',
'bert_model.ckpt.data-00000-of-00001': '734d5a1b68bf98d4e9cb6b6692725d00842a1937af73902e51776905d8f760ea',
'bert_model.ckpt.index': '517d6ef5c41fc2ca1f595276d6fccf5521810d57f5a74e32616151557790f7b1',
'bert_model.ckpt.meta': '5f8a9771ff25dadd61582abb4e3a748215a10a6b55947cbb66d0f0ba1694be98',
'vocab.txt': 'eeaa9875b23b04b4c54ef759d03db9d1ba1554838f8fb26c5d96fa551df93d02',
}
self.bert_large_cased_sha = {
'bert_config.json': '7adb2125c8225da495656c982fd1c5f64ba8f20ad020838571a3f8a954c2df57',
'bert_model.ckpt.data-00000-of-00001': '6ff33640f40d472f7a16af0c17b1179ca9dcc0373155fb05335b6a4dd1657ef0',
'bert_model.ckpt.index': 'ef42a53f577fbe07381f4161b13c7cab4f4fc3b167cec6a9ae382c53d18049cf',
'bert_model.ckpt.meta': 'd2ddff3ed33b80091eac95171e94149736ea74eb645e575d942ec4a5e01a40a1',
'vocab.txt': 'eeaa9875b23b04b4c54ef759d03db9d1ba1554838f8fb26c5d96fa551df93d02',
}
self.bert_base_multilingual_cased_sha = {
'bert_config.json': 'e76c3964bc14a8bb37a5530cdc802699d2f4a6fddfab0611e153aa2528f234f0',
'bert_model.ckpt.data-00000-of-00001': '55b8a2df41f69c60c5180e50a7c31b7cdf6238909390c4ddf05fbc0d37aa1ac5',
'bert_model.ckpt.index': '7d8509c2a62b4e300feb55f8e5f1eef41638f4998dd4d887736f42d4f6a34b37',
'bert_model.ckpt.meta': '95e5f1997e8831f1c31e5cf530f1a2e99f121e9cd20887f2dce6fe9e3343e3fa',
'vocab.txt': 'fe0fda7c425b48c516fc8f160d594c8022a0808447475c1a7c6d6479763f310c',
}
self.bert_large_multilingual_uncased_sha = {
'bert_config.json': '49063bb061390211d2fdd108cada1ed86faa5f90b80c8f6fdddf406afa4c4624',
'bert_model.ckpt.data-00000-of-00001': '3cd83912ebeb0efe2abf35c9f1d5a515d8e80295e61c49b75c8853f756658429',
'bert_model.ckpt.index': '87c372c1a3b1dc7effaaa9103c80a81b3cbab04c7933ced224eec3b8ad2cc8e7',
'bert_model.ckpt.meta': '27f504f34f02acaa6b0f60d65195ec3e3f9505ac14601c6a32b421d0c8413a29',
'vocab.txt': '87b44292b452f6c05afa49b2e488e7eedf79ea4f4c39db6f2f4b37764228ef3f',
}
self.bert_base_chinese_sha = {
'bert_config.json': '7aaad0335058e2640bcb2c2e9a932b1cd9da200c46ea7b8957d54431f201c015',
'bert_model.ckpt.data-00000-of-00001': '756699356b78ad0ef1ca9ba6528297bcb3dd1aef5feadd31f4775d7c7fc989ba',
'bert_model.ckpt.index': '46315546e05ce62327b3e2cd1bed22836adcb2ff29735ec87721396edb21b82e',
'bert_model.ckpt.meta': 'c0f8d51e1ab986604bc2b25d6ec0af7fd21ff94cf67081996ec3f3bf5d823047',
'vocab.txt': '45bbac6b341c319adc98a532532882e91a9cefc0329aa57bac9ae761c27b291c',
}
# Relate SHA to urls for loop below
self.model_sha = {
'bert_base_uncased': self.bert_base_uncased_sha,
'bert_large_uncased': self.bert_large_uncased_sha,
'bert_base_cased': self.bert_base_cased_sha,
'bert_large_cased': self.bert_large_cased_sha,
'bert_base_multilingual_cased': self.bert_base_multilingual_cased_sha,
'bert_large_multilingual_uncased': self.bert_large_multilingual_uncased_sha,
'bert_base_chinese': self.bert_base_chinese_sha
}
# Helper to get sha256sum of a file
def sha256sum(self, filename):
h = hashlib.sha256()
b = bytearray(128*1024)
mv = memoryview(b)
with open(filename, 'rb', buffering=0) as f:
for n in iter(lambda : f.readinto(mv), 0):
h.update(mv[:n])
return h.hexdigest()
def download(self):
# Iterate over urls: download, unzip, verify sha256sum
found_mismatch_sha = False
for model in self.model_urls:
url = self.model_urls[model][0]
file = self.save_path + '/' + self.model_urls[model][1]
print('Downloading', url)
response = urllib.request.urlopen(url)
with open(file, 'wb') as handle:
handle.write(response.read())
print('Unzipping', file)
with zipfile.ZipFile(file, 'r') as archive:
archive.extractall(self.save_path)
sha_dict = self.model_sha[model]
for extracted_file in sha_dict:
sha = sha_dict[extracted_file]
if sha != self.sha256sum(file[:-4] + '/' + extracted_file):
found_mismatch_sha = True
print('SHA256sum does not match on file:', extracted_file, 'from download url:', url)
else:
print(file[:-4] + '/' + extracted_file, '\t', 'verified')
if not found_mismatch_sha:
print("All downloads pass sha256sum verification.")
def serialize(self):
pass
def deserialize(self):
pass
def listAvailableWeights(self):
print("Available Weight Datasets")
for item in self.model_urls:
print(item)
def listLocallyStoredWeights(self):
pass
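# Example (hypothetical save path):
#   GooglePretrainedWeightDownloader('/workspace/data').download()
# downloads every checkpoint listed in model_urls, unzips it and verifies each
# extracted file against the SHA256 tables above.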
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
class NVIDIAPretrainedWeightDownloader:
def __init__(self, save_path):
self.save_path = save_path + '/nvidia_pretrained_weights'
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
def download(self):
assert False, 'NVIDIAPretrainedWeightDownloader not implemented yet.'