"""Train a MobileNetV2 classifier on ``Fer2013Dataset``, optionally under DDP.

The script builds a :class:`MobileNetV2Driver`, which owns the model, the
optimizer, the data loaders and the training / validation loops.
"""
# NOTE: the import order below is kept exactly as in the original file because
# ``from datawork import *`` may rebind names imported before it, and names
# imported after it rebind whatever the star import provided.
import torch
import torch.nn as nn
# from peselibs_config import get_lib_path
import sys
# sys.path.append(get_lib_path())
import DDP
import torchvision.models.mobilenet as mobilenet
import numpy as np  # explicit: ``np`` was previously only available via the star import
from datawork import *
from sklearn.metrics import accuracy_score
from fitlog import FitLog
from torch.utils.data import DataLoader
import time
import random

# Debug switch: when True, every training/validation loop stops after its
# first iteration.
g_dubug = False


class MobileNetV2Driver():
    """Own a MobileNetV2 model plus everything needed to train and validate it.

    The model is wrapped in ``DistributedDataParallel`` when ``local_rank`` is
    not ``None``; otherwise it runs on a single device (CUDA if available,
    else CPU).
    """

    def __init__(self, local_rank):
        """Build the model, optimizer, dataset and data loaders.

        Args:
            local_rank: CUDA device index of this process when running under
                DDP, or ``None`` for a single-process run.
        """
        self.nclass = 9
        self.batch_size = 64
        self.local_rank = local_rank
        self.nepoch = 500
        self.nround = 10
        # NOTE: ``self.lr`` is not used by the Adam optimizer below, which
        # hard-codes lr=0.001. It is kept because it is a public attribute.
        self.lr = 0.00001
        self.loader = None
        self.test_loader = None
        self.dataset = None
        self.device = None

        # Model & device.
        self.model = mobilenet.MobileNetV2(num_classes=self.nclass)
        self._init_device()
        self.model.to(self.device)
        if self.local_rank is not None:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                find_unused_parameters=True,
            )

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=0.001, betas=(0.9, 0.999))

        self.dataset = Fer2013Dataset(local_rank)

        # Building the sampler fails when no process group is initialised.
        # That is expected for single-process runs, so fall back to None
        # instead of crashing, but no longer swallow KeyboardInterrupt and
        # friends the way a bare ``except:`` did.
        try:
            self.sampler = torch.utils.data.distributed.DistributedSampler(
                self.dataset)
        except (RuntimeError, ValueError, AssertionError) as err:
            self.sampler = None
            if self.local_rank is not None:
                print("warning: DistributedSampler unavailable ({}); "
                      "training data will not be sharded".format(err))

        if self.local_rank is not None:
            # The sampler does the shuffling, so the loader itself must not.
            self.loader = DataLoader(
                self.dataset, batch_size=self.batch_size,
                sampler=self.sampler, shuffle=False)
        else:
            self.loader = DataLoader(
                self.dataset, batch_size=self.batch_size, shuffle=True)

        # The test loader shares the dataset object with the training loader;
        # ``_validate`` switches the dataset to 'test' mode before iterating.
        self.test_loader = DataLoader(
            self.dataset, batch_size=self.batch_size, shuffle=True)

    def _is_main_process(self):
        """Return True for a single-process run or for DDP local rank 0."""
        return self.local_rank is None or self.local_rank == 0

    def _init_device(self):
        """Select ``self.device`` from ``local_rank`` / CUDA availability."""
        if self.local_rank is not None:
            self.device = torch.device('cuda', self.local_rank)
            return
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

    def init_dataset(self, seed):
        """Forward ``seed`` to the dataset's ``randomization`` method."""
        self.dataset.randomization(seed)

    def _train_epoch(self, epoch, best_acc, best_acc_at, st_time):
        """Run one pass over ``self.loader`` and return the per-batch losses.

        Args:
            epoch: Zero-based epoch index (used for logging and for seeding
                the distributed sampler's shuffle).
            best_acc: Best validation accuracy so far (logging only).
            best_acc_at: Epoch at which ``best_acc`` was reached (logging only).
            st_time: ``time.time()`` value taken when training started.

        Returns:
            List of ``loss.item()`` values, one per processed batch.
        """
        is_main = self._is_main_process()
        self.dataset.set_mode("train")
        self.model.train()
        # Without set_epoch the DistributedSampler yields the same ordering in
        # every epoch.
        if self.sampler is not None:
            self.sampler.set_epoch(epoch)

        all_loss = []
        for batch_idx, (data, target) in enumerate(self.loader):
            data, target = data.to(self.device), target.to(self.device)
            # Kept from the original loop; presumably idempotent - TODO confirm
            # whether set_mode has side effects beyond selecting the split.
            self.dataset.set_mode("train")

            # ---- per-step timing (main process only) ----
            if is_main:
                jishi1 = time.time()
            self.model.train()
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            if is_main:
                jishi2 = time.time()
                jishi2_log = '****epc:{},process:{}/{},best_acc:{},start:{},end:{},duration:{}****'.format(
                    epoch, batch_idx * len(data), len(self.loader.dataset),
                    best_acc, jishi1, jishi2, jishi2 - jishi1)
                self.jishilog.append(jishi2_log)
                print(jishi2_log)

            all_loss.append(loss.item())
            t1 = time.time()
            duration = t1 - st_time
            if batch_idx % 10 == 0 and is_main:
                btstr = 'epc: {} [{}/{} ({:.0f}%)] loss: {:.6f} b-acc: {:.3f} @:{},curtime:{},duration:{}'.format(
                    epoch, batch_idx * len(data), len(self.loader.dataset),
                    100. * batch_idx / len(self.loader), loss.item(),
                    best_acc, best_acc_at, str(t1), str(duration))
                self.fitlog.append(btstr)
            if g_dubug:
                break
        return all_loss

    def train(self):
        """Train for ``self.nepoch`` epochs, validating after each one.

        Only the main process (single-process run or local rank 0) writes
        logs, saves the model and runs validation.
        """
        best_acc = 0
        best_acc_at = 0
        is_main = self._is_main_process()
        if is_main:
            self.fitlog = FitLog("./logs/")
            self.jishilog = FitLog("./logs/", prefix='jishi')
            self.dlog = FitLog("./logs/", prefix='pred')
        st_time = time.time()

        try:
            for epoch in range(self.nepoch):
                all_loss = self._train_epoch(
                    epoch, best_acc, best_acc_at, st_time)

                if is_main:
                    # Saving from a single process avoids several ranks
                    # writing the same file at the same time.
                    torch.save(self.model, './mobilenet.pth')

                    duration = time.time() - st_time
                    acc, vloss, vloss_std, all_pred, all_tar = self._validate()
                    epcstr = '****epc:{},loss:{:.6f},loss_std:{:.6f},vloss:{:.6f},vloss_std:{:.6f},acc:{:.3f},duration:{}****'.format(
                        epoch, np.mean(all_loss), np.std(all_loss),
                        vloss, vloss_std, acc, str(duration))
                    self.dlog.append(epcstr)
                    self.dlog.append("pred" + str(all_pred))
                    self.dlog.append("tar" + str(all_tar))
                    if acc > best_acc:
                        best_acc = acc
                        best_acc_at = epoch
                    print(epcstr)
                if g_dubug:
                    break
        finally:
            # Close the logs even when training aborts with an exception.
            if is_main:
                self.fitlog.close()
                self.dlog.close()
                self.jishilog.close()

    def _validate(self):
        """Evaluate the model on ``self.test_loader``.

        Returns:
            Tuple ``(accuracy, mean_loss, loss_std, all_pred, all_tar)`` where
            ``accuracy`` is computed once over every prediction, the loss
            statistics are over per-batch losses, and ``all_pred`` /
            ``all_tar`` are flat lists of predicted and target labels.
        """
        self.model.eval()
        self.dataset.set_mode('test')

        # ``train`` calls this on the main process only. Going through the DDP
        # wrapper may trigger collective buffer broadcasts that the other
        # ranks never join, so evaluate the wrapped module directly.
        eval_model = self.model
        if isinstance(eval_model, nn.parallel.DistributedDataParallel):
            eval_model = eval_model.module

        all_pred = []
        all_tar = []
        all_loss = []
        with torch.no_grad():
            for i, (ft, labs) in enumerate(self.test_loader):
                ft, labs = ft.to(self.device), labs.to(self.device)
                output = eval_model(ft)
                loss = self.criterion(output, labs)
                all_pred.extend(torch.argmax(output, dim=1).cpu().numpy().tolist())
                all_tar.extend(labs.cpu().numpy().tolist())
                all_loss.append(loss.item())
                if i % 100 == 0:
                    print('validating @ batch {}'.format(i))
                if g_dubug:
                    break

        # Previously the accuracy was the mean of cumulative accuracies taken
        # after every batch, which over-weights early batches and rescans all
        # predictions each step. Compute it once over the full set instead.
        acc = accuracy_score(all_tar, all_pred) if all_tar else float('nan')
        return acc, np.mean(all_loss), np.std(all_loss), all_pred, all_tar

    def run(self, iround):
        """Randomize the dataset with ``iround`` as seed, then train."""
        self.init_dataset(iround)
        self.train()


def main():
    """Seed the RNGs, initialise DDP, train once and print the elapsed time."""
    print(torch.cuda.is_available())

    # Fix every RNG used by the pipeline.
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    t0 = time.time()
    local_rank = DDP.init_ddp()
    print("local_rank=", local_rank)
    driver = MobileNetV2Driver(local_rank=local_rank)
    iround = 1
    driver.run(iround)
    t1 = time.time()
    # time.time() is in seconds; the former "/ 1000" produced a meaningless
    # unit, so report plain elapsed seconds.
    print("result_time=", t1 - t0)


if __name__ == '__main__':
    main()