Commit b8d3ff26 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add multi.py

parent b7a31755
import os
from fitlog import FitLog
from ShuffleNet.model import ShuffleNetV2
from iframe_feeder import iframe_feeder
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm, trange
from torch.optim.lr_scheduler import MultiStepLR
import torch.nn as nn
class ShuffleNetV2Driver:
def __init__(self, featd, feati, normd, normi, lab):
self.batchsize = 256
self.lr = 0.01
self.momentum = 0.9
self.decay = 4e-5
self.gamma = 0.1
self.schedule = [200, 300]
self.local_rank = int(os.getenv("LOCAL_RANK", -1))
self.RANK = int(os.getenv("RANK", -1))
self.feeder = iframe_feeder(featd, feati, lab, normd, normi)
self.feeder.set_mode('train')
if self.local_rank >= 0:
self.device = torch.device('cuda', self.local_rank)
else:
self.device = torch.device('cuda')
if self.local_rank >= 0:
self.sampler = torch.utils.data.distributed.DistributedSampler(self.feeder)
self.loader = DataLoader(self.feeder, batch_size=self.batchsize,sampler=self.sampler, shuffle=False)
else:
self.loader = DataLoader(self.feeder, batch_size=self.batchsize, shuffle=True, num_workers=0)
#self.feeder = iframe_feeder(featd, feati, lab, normd, normi)
#self.feeder.set_mode('train')
self.model = ShuffleNetV2(num_classes=2, scale=0.5, SE=True, residual=True)#torchvision.models.ShuffleNetV2(num_classes=2)#
self.fitlog = FitLog()
self.detail_log = FitLog(prefix='dt_')
#self.device = self.get_device() #torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device info: " + str(self.device))
self.optimizer = torch.optim.SGD(self.model.parameters(),\
self.lr, momentum=self.momentum, weight_decay=self.decay,\
nesterov=True)
self.scheduler = MultiStepLR(self.optimizer, self.schedule, self.gamma)
self.criterion = torch.nn.CrossEntropyLoss()
#self.loader = DataLoader(self.feeder, batch_size=self.batchsize, shuffle=True, num_workers=0)
self.print_interval = 25
self.n_epoch = 200
print('self_device:',self.device)
print('self_local_rank:',self.local_rank)
self.model.to(self.device)
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[self.local_rank],output_device=self.local_rank, find_unused_parameters=True)
def get_device(self):
device_mlu = None
device_gpu = None
try:
# device_mlu = torch.device('mlu')
device_gpu = torch.device('cuda')
except Exception as err:
print(err)
if device_mlu:
self.fitlog.append('mlu', True, True)
return device_mlu
elif device_gpu:
self.fitlog.append('cuda', True, True)
return device_gpu
else:
self.fitlog.append('cpu', True, True)
return torch.device('cpu')
def train(self, epoch):
self.feeder.set_mode('train')
self.model.train()
nbatch = len(self.loader)
print('nbatch:',nbatch)
for batch_idx, (feats, labs) in enumerate(tqdm(self.loader)):
feats = feats.to(self.device)
labs = labs.to(self.device)
self.optimizer.zero_grad()
res = self.model(feats)
loss = self.criterion(res, labs)
loss.backward()
self.optimizer.step()
if batch_idx % self.print_interval == 0:
xstr = "Train: epoch: {} batch: {}/{}, loss: {:.6f}".format(epoch, batch_idx, nbatch, loss)
tqdm.write(xstr)
self.fitlog.append(xstr, True, True)
def validate(self):
self.feeder.set_mode('test')
self.model.eval()
loss_val = 0
n_correct = 0
n_total = 0
for batch_idx, (feats, labs) in enumerate(tqdm(self.loader)):
feats = feats.to(self.device)
labs = labs.to(self.device)
with torch.no_grad():
res = self.model(feats)
loss_val += self.criterion(res, labs).item()
_, pred = res.max(1)
n_correct += pred.eq(labs).sum().item()
n_total += labs.shape[0]
loss_val = loss_val / len(self.loader)
acc = n_correct / n_total * 100
self.detail_log.append(str(labs.tolist()))
self.detail_log.append(str(pred.tolist()))
xstr = "Validation: avg loss: {:.4f}, avg acc: {:.4f}%".format(loss_val, acc)
tqdm.write(xstr)
self.fitlog.append(xstr, True, True)
def finish(self):
self.feeder.finish()
self.fitlog.close()
self.detail_log.close()
def run(self):
for i in range(1, self.n_epoch + 1):
self.train(i)
self.validate()
if __name__ == "__main__":
datafolder = "data/"
torch.distributed.init_process_group(backend="nccl")
driver = ShuffleNetV2Driver(datafolder + "s2_ftimgd",\
datafolder + "s2_ftimgi",\
datafolder + "s2_normd",\
datafolder + "s2_normi",\
datafolder + "s2_label.json")
driver.run()
driver.finish()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment