"vscode:/vscode.git/clone" did not exist on "492143bf32b25848300dcc18bd51fef6c25d02d7"
Commit cea8529d authored by zhangqha@sugon.com (parent f5e0c216)

Commit message: "Upload code"
from torchvision.transforms import autoaugment, transforms
class ClassificationPresetTrain:
def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), hflip_prob=0.5,
auto_augment_policy=None, random_erase_prob=0.0):
trans = [transforms.RandomResizedCrop(crop_size)]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy))
trans.extend([
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
if random_erase_prob > 0:
trans.append(transforms.RandomErasing(p=random_erase_prob))
self.transforms = transforms.Compose(trans)
def __call__(self, img):
return self.transforms(img)
class ClassificationPresetEval:
def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
self.transforms = transforms.Compose([
transforms.Resize(resize_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
def __call__(self, img):
return self.transforms(img)
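# Illustrative usage sketch (not part of the original file): apply the train
# and eval presets above to a PIL image. The image path is hypothetical.
def _preset_usage_example():
    from PIL import Image
    train_tf = ClassificationPresetTrain(crop_size=224, auto_augment_policy="imagenet",
                                         random_erase_prob=0.1)
    eval_tf = ClassificationPresetEval(crop_size=224, resize_size=256)
    img = Image.open("example.jpg").convert("RGB")
    train_tensor = train_tf(img)  # augmented + normalized tensor of shape (3, 224, 224)
    eval_tensor = eval_tf(img)    # resized + center-cropped tensor of shape (3, 224, 224)
    print(train_tensor.shape, eval_tensor.shape)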
sphinx==2.4.4
sphinx-gallery>=0.9.0
sphinx-copybutton>=0.3.1
matplotlib
numpy
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
# Disable MIOpen's implicit-GEMM convolution kernels (MIOpen debug switch)
export MIOPEN_DEBUG_CONV_IMPLICIT_GEMM=0
# Make only GPU 2 visible to the ROCm/HIP runtime
export HIP_VISIBLE_DEVICES=2
# Validate a pretrained mobilenet_v3_large on ImageNet via the ONNX/TVM path
python val_onnx.py --test-only --data-path /parastor/DL_DATA/ImageNet-pytorch --model mobilenet_v3_large -b 1 --pretrained
"""Test script for torch module"""
import torch
import time
import tvm
from tvm.contrib.torch import compile
import torch.onnx as onnx
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.rand([1, 3, 224, 224]).to(device)
#checkpoint = torch.load("./yolov8n.pt")
#model.load(weights='yolov8n.pt')
#torch.onnx.export(model.predictor.model, x, "./yolov8n.onnx")
#model.export(format='onnx',imgsz=[384, 640], device="cuda")
model_jit = torch.jit.load("./model.pt")
option = {
"input_infos": [
("x", (1, 3, 224, 224)),
],
"default_dtype": "float32",
"export_dir": "pytorch_compiled",
"num_outputs": 1,
"tuning_n_trials": 0, # set zero to skip tuning
"tuning_log_file": "tuning.log",
"target": "rocm --libs=miopen,rocblas",
"device": tvm.rocm(),
}
pytorch_tvm_module = compile(model_jit, option)
print("Run PyTorch...")
for i in range(1):
t = time.time()
# module = torch.jit.load("./model_tvm.pt");
# outputs=module([x])
outputs = pytorch_tvm_module.forward([x])
torch.cuda.synchronize()
print(1000 * (time.time() - t))
print(outputs[0].shape)
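# Optional sanity check (a sketch, assuming `model_jit`, `x`, `device` and
# `pytorch_tvm_module` from above): compare the TVM-compiled output against
# the original TorchScript model on the same input.
import numpy as np

with torch.no_grad():
    ref = model_jit.to(device)(x)
tvm_out = pytorch_tvm_module.forward([x])[0]
np.testing.assert_allclose(ref.cpu().numpy(), tvm_out.cpu().numpy(), rtol=1e-3, atol=1e-3)
print("TVM output matches the TorchScript reference within tolerance")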
from collections import defaultdict, deque, OrderedDict
import copy
import datetime
import hashlib
import time
import torch
import torch.distributed as dist
import errno
import os
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
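# Illustrative sketch (not part of the original file): track a noisy
# per-iteration metric with SmoothedValue; the values below are made up.
def _smoothed_value_example():
    loss_meter = SmoothedValue(window_size=5, fmt="{median:.3f} ({global_avg:.3f})")
    for v in [2.0, 1.5, 1.2, 1.1, 1.05, 1.0]:
        loss_meter.update(v)
    print(str(loss_meter))                # median over the window and global average
    print(loss_meter.max, loss_meter.value)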
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
if torch.cuda.is_available():
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
else:
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {}'.format(header, total_time_str))
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target[None])
res = []
for k in topk:
correct_k = correct[:k].flatten().sum(dtype=torch.float32)
res.append(correct_k * (100.0 / batch_size))
return res
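# Toy sketch (not part of the original file): accuracy() on a batch of three
# samples and four classes. The logits are made up so that two samples are
# correct at top-1 and all three fall within the top-2.
def _accuracy_example():
    logits = torch.tensor([[4.0, 1.0, 0.0, 0.0],
                           [0.0, 3.0, 1.0, 0.0],
                           [2.0, 0.5, 1.5, 0.0]])
    labels = torch.tensor([0, 1, 2])
    top1, top2 = accuracy(logits, labels, topk=(1, 2))
    print(top1.item(), top2.item())  # approximately 66.67 and 100.0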
def mkdir(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
setup_for_distributed(args.rank == 0)
def average_checkpoints(inputs):
"""Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
Args:
inputs (List[str]): An iterable of string paths of checkpoints to load from.
Returns:
A dict of string keys mapping to various values. The 'model' key
from the returned dict should correspond to an OrderedDict mapping
string parameter names to torch Tensors.
"""
params_dict = OrderedDict()
params_keys = None
new_state = None
num_models = len(inputs)
for fpath in inputs:
with open(fpath, "rb") as f:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
),
)
# Copies over the settings from the first checkpoint
if new_state is None:
new_state = state
model_params = state["model"]
model_params_keys = list(model_params.keys())
if params_keys is None:
params_keys = model_params_keys
elif params_keys != model_params_keys:
raise KeyError(
"For checkpoint {}, expected list of params: {}, "
"but found: {}".format(f, params_keys, model_params_keys)
)
for k in params_keys:
p = model_params[k]
if isinstance(p, torch.HalfTensor):
p = p.float()
if k not in params_dict:
params_dict[k] = p.clone()
# NOTE: clone() is needed in case of p is a shared parameter
else:
params_dict[k] += p
averaged_params = OrderedDict()
for k, v in params_dict.items():
averaged_params[k] = v
if averaged_params[k].is_floating_point():
averaged_params[k].div_(num_models)
else:
averaged_params[k] //= num_models
new_state["model"] = averaged_params
return new_state
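# Usage sketch (not part of the original file): average the last few epoch
# checkpoints written by the training script and save the result. The file
# names are hypothetical; each file must contain a 'model' key.
def _average_checkpoints_example():
    ckpts = ["model_87.pth", "model_88.pth", "model_89.pth"]
    averaged = average_checkpoints(ckpts)
    torch.save(averaged, "model_averaged.pth")
    # averaged["model"] can then be loaded via model.load_state_dict(...)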
def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True):
"""
This method can be used to prepare weights files for new models. It receives as
input a model architecture and a checkpoint from the training script and produces
a file with the weights ready for release.
Examples:
from torchvision import models as M
# Classification
model = M.mobilenet_v3_large(pretrained=False)
print(store_model_weights(model, './class.pth'))
# Quantized Classification
model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
_ = torch.quantization.prepare_qat(model, inplace=True)
print(store_model_weights(model, './qat.pth'))
# Object Detection
model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False)
print(store_model_weights(model, './obj.pth'))
# Segmentation
model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True)
print(store_model_weights(model, './segm.pth', strict=False))
Args:
model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
checkpoint_path (str): The path of the checkpoint we will load.
checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
Default: "model".
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
Returns:
output_path (str): The location where the weights are saved.
"""
# Store the new model next to the checkpoint_path
checkpoint_path = os.path.abspath(checkpoint_path)
output_dir = os.path.dirname(checkpoint_path)
# Deep copy to avoid side-effects on the model object.
model = copy.deepcopy(model)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Load the weights to the model to validate that everything works
# and remove unnecessary weights (such as auxiliaries, etc)
model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
tmp_path = os.path.join(output_dir, str(model.__hash__()))
torch.save(model.state_dict(), tmp_path)
sha256_hash = hashlib.sha256()
with open(tmp_path, "rb") as f:
# Read and update hash string value in blocks of 4K
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
hh = sha256_hash.hexdigest()
output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
os.replace(tmp_path, output_path)
return output_path
import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
import torchvision
import tvm
from tvm.contrib.torch import compile
import presets
import utils
import numpy as np
try:
from apex import amp
except ImportError:
amp = None
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))
header = 'Epoch: [{}]'.format(epoch)
for image, target in metric_logger.log_every(data_loader, print_freq, header):
start_time = time.time()
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
if apex:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
optimizer.step()
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = image.shape[0]
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
def evaluate(model, criterion, data_loader, device, print_freq=100):
model.eval()
x = torch.rand([1, 3, 224, 224]).cuda()
    # Convert the model to TorchScript by tracing it on the dummy input
    scripted_model = torch.jit.trace(model, x)
    # Save the TorchScript model as a .pt file
    torch.jit.save(scripted_model, "model.pt")
model_jit = torch.jit.load("./model.pt")
option = {
"input_infos": [
("x", (1, 3, 224, 224)),
],
"default_dtype": "float32",
"export_dir": "pytorch_compiled",
"num_outputs": 1,
"tuning_n_trials": 0, # set zero to skip tuning
"tuning_log_file": "tuning.log",
"target": "rocm --libs=miopen,rocblas",
"device": tvm.rocm(),
}
pytorch_tvm_module = compile(model_jit, option)
    # Run the compiled module twice up front as a warm-up before evaluation
    preds = pytorch_tvm_module.forward([x])
    preds = pytorch_tvm_module.forward([x])
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
with torch.no_grad():
for image, target in metric_logger.log_every(data_loader, print_freq, header):
image = image.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
            output = model(image)
            print("original PyTorch output:", output)
            # The TVM module returns a list of output tensors; take the first
            output = pytorch_tvm_module.forward([image])[0]
            print("TVM output:", output)
            print("target:", target)
            loss = criterion(output, target)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
# FIXME need to take into account that the datasets
# could have been padded in distributed setup
batch_size = image.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
.format(top1=metric_logger.acc1, top5=metric_logger.acc5))
return metric_logger.acc1.global_avg
def _get_cache_path(filepath):
import hashlib
h = hashlib.sha1(filepath.encode()).hexdigest()
cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
cache_path = os.path.expanduser(cache_path)
return cache_path
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)
print("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_train from {}".format(cache_path))
dataset, _ = torch.load(cache_path)
else:
auto_augment_policy = getattr(args, "auto_augment", None)
random_erase_prob = getattr(args, "random_erase", 0.0)
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob))
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st)
print("Loading validation data")
cache_path = _get_cache_path(valdir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
else:
dataset_test = torchvision.datasets.ImageFolder(
valdir,
presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)
print("Creating data loaders")
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
else:
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
return dataset, dataset_test, train_sampler, test_sampler
def main(args):
if args.apex and amp is None:
raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
"to enable mixed-precision training.")
if args.output_dir:
utils.mkdir(args.output_dir)
utils.init_distributed_mode(args)
print(args)
device = torch.device(args.device)
torch.backends.cudnn.benchmark = True
train_dir = os.path.join(args.data_path, 'train')
val_dir = os.path.join(args.data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers, pin_memory=True)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=args.batch_size,
sampler=test_sampler, num_workers=args.workers, pin_memory=True)
print("Creating model")
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
criterion = nn.CrossEntropyLoss()
opt_name = args.opt.lower()
if opt_name == 'sgd':
optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
elif opt_name == 'rmsprop':
optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
else:
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
if args.apex:
model, optimizer = amp.initialize(model, optimizer,
opt_level=args.apex_opt_level
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.test_only:
evaluate(model, criterion, data_loader_test, device=device)
return
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if args.output_dir:
checkpoint = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args}
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'checkpoint.pth'))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help)
parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
parser.add_argument('--model', default='resnet18', help='model')
parser.add_argument('--device', default='cuda', help='device')
parser.add_argument('-b', '--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument(
"--cache-dataset",
dest="cache_dataset",
help="Cache the datasets for quicker initialization. It also serializes the transforms",
action="store_true",
)
parser.add_argument(
"--sync-bn",
dest="sync_bn",
help="Use sync batch norm",
action="store_true",
)
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
parser.add_argument(
"--pretrained",
dest="pretrained",
help="Use pre-trained models from the modelzoo",
action="store_true",
)
parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
# Mixed precision training parameters
parser.add_argument('--apex', action='store_true',
help='Use apex for mixed precision training')
    parser.add_argument('--apex-opt-level', default='O1', type=str,
                        help='For apex mixed precision training. '
                             'O0 for FP32 training, O1 for mixed precision training. '
                             'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
                        )
# distributed training parameters
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
return parser
if __name__ == "__main__":
args = get_args_parser().parse_args()
main(args)
import datetime
import os
import time
import numpy as np
import onnx
import torch
import torch.utils.data
from torch import nn
import torchvision
import tvm
from tvm.contrib.torch import compile
from tvm import relay
from tvm.contrib import graph_executor
import presets
import utils
try:
from apex import amp
except ImportError:
amp = None
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
#metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))
header = 'Epoch: [{}]'.format(epoch)
for image, target in metric_logger.log_every(data_loader, print_freq, header):
start_time = time.time()
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
if apex:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
optimizer.step()
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = image.shape[0]
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
def evaluate(model, criterion, data_loader, device, print_freq=100):
model.eval()
    # Create a dummy input on the GPU for ONNX export
    input_shape = (1, 3, 224, 224)
    input_tensor = torch.randn(input_shape).to(torch.device("cuda"))
    # Export the model to an ONNX file
    output_path = "./mobilenet.onnx"
    torch.onnx.export(model, input_tensor, output_path)
    # Build a random input and shape dict for importing the ONNX model into Relay
    img_data = np.random.rand(1, 3, 224, 224).astype("float32") / 255
    input_name = "input.1"
    shape_dict = {input_name: img_data.shape}
    model_path = './mobilenet.onnx'
    onnx_model = onnx.load(model_path)
# Define the neural network and compilation target
batch_size = 1
layout = "NCHW"
target = "rocm -libs=miopen,rocblas"
dtype = "float32"
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, dtype=dtype)
# Compile with the history best
print("Compile...")
#mod = tvm.relay.transform.ToMixedPrecision(mixed_precision_type='float16', missing_op_mode=1)(mod)
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
print('Compile success!')
dev = tvm.device(str(target), 0)
m = graph_executor.GraphModule(lib["default"](dev))
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
with torch.no_grad():
for image, target in metric_logger.log_every(data_loader, print_freq, header):
# image = image.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
# output = model(image)
m.set_input("input.1", image.numpy())
m.run()
out = m.get_output(0).asnumpy()
output = torch.from_numpy(out).to(torch.device("cuda"))
loss = criterion(output, target)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
# FIXME need to take into account that the datasets
# could have been padded in distributed setup
batch_size = image.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
.format(top1=metric_logger.acc1, top5=metric_logger.acc5))
return metric_logger.acc1.global_avg
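# Debugging sketch (not part of the original file): run one random input
# through both the PyTorch model and the compiled Relay graph module and
# compare the outputs. Assumes `model` is on CUDA and the module uses the
# same "input.1" input name as the ONNX export in evaluate().
def _compare_pytorch_and_relay(model, m):
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        ref = model(dummy.to(torch.device("cuda"))).cpu().numpy()
    m.set_input("input.1", dummy.numpy())
    m.run()
    tvm_out = m.get_output(0).asnumpy()
    np.testing.assert_allclose(ref, tvm_out, rtol=1e-3, atol=1e-3)
    print("Relay output matches PyTorch within tolerance")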
def _get_cache_path(filepath):
import hashlib
h = hashlib.sha1(filepath.encode()).hexdigest()
cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
cache_path = os.path.expanduser(cache_path)
return cache_path
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)
print("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_train from {}".format(cache_path))
dataset, _ = torch.load(cache_path)
else:
auto_augment_policy = getattr(args, "auto_augment", None)
random_erase_prob = getattr(args, "random_erase", 0.0)
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob))
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st)
print("Loading validation data")
cache_path = _get_cache_path(valdir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
else:
dataset_test = torchvision.datasets.ImageFolder(
valdir,
presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)
print("Creating data loaders")
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
else:
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
return dataset, dataset_test, train_sampler, test_sampler
def main(args):
if args.apex and amp is None:
raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
"to enable mixed-precision training.")
if args.output_dir:
utils.mkdir(args.output_dir)
utils.init_distributed_mode(args)
print(args)
device = torch.device(args.device)
torch.backends.cudnn.benchmark = True
train_dir = os.path.join(args.data_path, 'train')
val_dir = os.path.join(args.data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers, pin_memory=True)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=args.batch_size,
sampler=test_sampler, num_workers=args.workers, pin_memory=True)
print("Creating model")
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
criterion = nn.CrossEntropyLoss()
opt_name = args.opt.lower()
if opt_name == 'sgd':
optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
elif opt_name == 'rmsprop':
optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
else:
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
if args.apex:
model, optimizer = amp.initialize(model, optimizer,
opt_level=args.apex_opt_level
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.test_only:
evaluate(model, criterion, data_loader_test, device=device)
return
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if args.output_dir:
checkpoint = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args}
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'checkpoint.pth'))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help)
parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
parser.add_argument('--model', default='resnet18', help='model')
parser.add_argument('--device', default='cuda', help='device')
parser.add_argument('-b', '--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument(
"--cache-dataset",
dest="cache_dataset",
help="Cache the datasets for quicker initialization. It also serializes the transforms",
action="store_true",
)
parser.add_argument(
"--sync-bn",
dest="sync_bn",
help="Use sync batch norm",
action="store_true",
)
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
parser.add_argument(
"--pretrained",
dest="pretrained",
help="Use pre-trained models from the modelzoo",
action="store_true",
)
parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
# Mixed precision training parameters
parser.add_argument('--apex', action='store_true',
help='Use apex for mixed precision training')
    parser.add_argument('--apex-opt-level', default='O1', type=str,
                        help='For apex mixed precision training. '
                             'O0 for FP32 training, O1 for mixed precision training. '
                             'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
                        )
# distributed training parameters
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
return parser
if __name__ == "__main__":
args = get_args_parser().parse_args()
main(args)