Unverified commit da01c234 authored by Frank Lee, committed by GitHub

Develop/experiments (#59)



* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>

* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)

* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset

* update api for better usability (#58)

update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent eb2f8b1f
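The history above repeatedly mentions gradient accumulation. As a minimal, framework-free sketch of the idea (the `sgd_step`/`train` names and the plain-list "parameters" are illustrative stand-ins, not Colossal-AI's API): gradients from several micro-batches are averaged before a single optimizer step, so the update matches one large batch.

```python
# Sketch of gradient accumulation with a toy SGD "optimizer".
# All names are illustrative; this is not Colossal-AI's API.

def sgd_step(params, grads, lr=0.1):
    """Apply one in-place SGD update."""
    for i, g in enumerate(grads):
        params[i] -= lr * g

def train(params, micro_batch_grads, accum_steps=4, lr=0.1):
    acc = [0.0] * len(params)          # accumulated gradients
    for step, grad in enumerate(micro_batch_grads, start=1):
        for i, g in enumerate(grad):   # accumulate, scaled so the final
            acc[i] += g / accum_steps  # update matches one big batch
        if step % accum_steps == 0:
            sgd_step(params, acc, lr)
            acc = [0.0] * len(params)  # reset between updates
    return params
```

With four micro-batches of gradient 1.0 and `lr=0.1`, a parameter starting at 1.0 takes exactly one step down to 0.9, just as a single full batch would.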
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
import torchvision.models as models

from colossalai.builder import build_model

NUM_CLS = 10

RESNET18 = dict(
    type='VanillaResNet',
    block_type='ResNetBasicBlock',
    layers=[2, 2, 2, 2],
    num_cls=NUM_CLS
)
RESNET34 = dict(
    type='VanillaResNet',
    block_type='ResNetBasicBlock',
    layers=[3, 4, 6, 3],
    num_cls=NUM_CLS
)
RESNET50 = dict(
    type='VanillaResNet',
    block_type='ResNetBottleneck',
    layers=[3, 4, 6, 3],
    num_cls=NUM_CLS
)
RESNET101 = dict(
    type='VanillaResNet',
    block_type='ResNetBottleneck',
    layers=[3, 4, 23, 3],
    num_cls=NUM_CLS
)
RESNET152 = dict(
    type='VanillaResNet',
    block_type='ResNetBottleneck',
    layers=[3, 8, 36, 3],
    num_cls=NUM_CLS
)


def compare_model(data, colossal_model, torchvision_model):
    colossal_output = colossal_model(data)
    torchvision_output = torchvision_model(data)
    assert colossal_output[0].shape == torchvision_output.shape, \
        f'{colossal_output[0].shape}, {torchvision_output.shape}'


@pytest.mark.cpu
def test_vanilla_resnet():
    """Compare Colossal-AI ResNet variants against torchvision references."""
    # data
    x = torch.randn((2, 3, 224, 224))

    # resnet 18
    col_resnet18 = build_model(RESNET18)
    col_resnet18.build_from_cfg()
    torchvision_resnet18 = models.resnet18(num_classes=NUM_CLS)
    compare_model(x, col_resnet18, torchvision_resnet18)

    # resnet 34
    col_resnet34 = build_model(RESNET34)
    col_resnet34.build_from_cfg()
    torchvision_resnet34 = models.resnet34(num_classes=NUM_CLS)
    compare_model(x, col_resnet34, torchvision_resnet34)

    # resnet 50
    col_resnet50 = build_model(RESNET50)
    col_resnet50.build_from_cfg()
    torchvision_resnet50 = models.resnet50(num_classes=NUM_CLS)
    compare_model(x, col_resnet50, torchvision_resnet50)

    # resnet 101
    col_resnet101 = build_model(RESNET101)
    col_resnet101.build_from_cfg()
    torchvision_resnet101 = models.resnet101(num_classes=NUM_CLS)
    compare_model(x, col_resnet101, torchvision_resnet101)

    # resnet 152
    col_resnet152 = build_model(RESNET152)
    col_resnet152.build_from_cfg()
    torchvision_resnet152 = models.resnet152(num_classes=NUM_CLS)
    compare_model(x, col_resnet152, torchvision_resnet152)


if __name__ == '__main__':
    test_vanilla_resnet()
import os
from pathlib import Path
BATCH_SIZE = 512
IMG_SIZE = 32
PATCH_SIZE = 4
DIM = 512
NUM_ATTENTION_HEADS = 8
SUMMA_DIM = 2
NUM_CLASSES = 10
DEPTH = 6
train_data = dict(
dataset=dict(
type='CIFAR10Dataset',
root=Path(os.environ['DATA']),
transform_pipeline=[
dict(type='RandomCrop', size=IMG_SIZE, padding=4),
dict(type='RandomHorizontalFlip'),
dict(type='ToTensor'),
dict(type='Normalize',
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]),
]),
dataloader=dict(batch_size=BATCH_SIZE,
pin_memory=True,
num_workers=4,
shuffle=True))
test_data = dict(
dataset=dict(
type='CIFAR10Dataset',
root=Path(os.environ['DATA']),
train=False,
transform_pipeline=[
dict(type='Resize', size=IMG_SIZE),
dict(type='ToTensor'),
dict(type='Normalize',
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]),
]),
dataloader=dict(batch_size=400,
pin_memory=True,
num_workers=4,
shuffle=True))
optimizer = dict(type='Adam', lr=0.001, weight_decay=0)
loss = dict(type='CrossEntropyLoss2D', )
model = dict(
type='VisionTransformerFromConfig',
tensor_splitting_cfg=dict(type='ViTInputSplitter2D', ),
embedding_cfg=dict(
type='ViTPatchEmbedding2D',
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=DIM,
),
token_fusion_cfg=dict(type='ViTTokenFuser2D',
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=DIM,
drop_rate=0.1),
norm_cfg=dict(
type='LayerNorm2D',
normalized_shape=DIM,
eps=1e-6,
),
block_cfg=dict(
type='ViTBlock',
attention_cfg=dict(
type='ViTSelfAttention2D',
hidden_size=DIM,
num_attention_heads=NUM_ATTENTION_HEADS,
attention_dropout_prob=0.,
hidden_dropout_prob=0.1,
),
droppath_cfg=dict(type='VanillaViTDropPath', ),
mlp_cfg=dict(type='ViTMLP2D',
in_features=DIM,
dropout_prob=0.1,
mlp_ratio=1),
norm_cfg=dict(
type='LayerNorm2D',
normalized_shape=DIM,
eps=1e-6,
),
),
head_cfg=dict(
type='ViTHead2D',
hidden_size=DIM,
num_classes=NUM_CLASSES,
),
embed_dim=DIM,
depth=DEPTH,
drop_path_rate=0.,
)
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2d'),
)
num_epochs = 60
lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs)
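Config files like the one above are plain Python dicts whose `type` key names a registered class, with the remaining keys passed as constructor arguments. A self-contained sketch of that registry/builder pattern (the `REGISTRY`, `register`, and `build_from_cfg` names here are stand-ins, not Colossal-AI internals):

```python
# Sketch of the type-keyed builder pattern behind these config dicts.
# REGISTRY and the helper names are stand-ins, not Colossal-AI internals.

REGISTRY = {}

def register(cls):
    """Class decorator: make the class reachable by its name."""
    REGISTRY[cls.__name__] = cls
    return cls

def build_from_cfg(cfg):
    """Pop 'type', look the class up, pass the rest as keyword arguments."""
    cfg = dict(cfg)                  # don't mutate the caller's dict
    cls = REGISTRY[cfg.pop('type')]
    return cls(**cfg)

@register
class LinearWarmupLR:
    def __init__(self, warmup_steps, total_steps):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

sched = build_from_cfg(dict(type='LinearWarmupLR', warmup_steps=5, total_steps=60))
```

Because the dispatch happens by string lookup, the same config file can describe models, losses, hooks, and schedulers without importing any of them directly.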
import os
from pathlib import Path
BATCH_SIZE = 512
IMG_SIZE = 32
PATCH_SIZE = 4
DIM = 512
NUM_ATTENTION_HEADS = 8
SUMMA_DIM = 2
NUM_CLASSES = 10
DEPTH = 6
train_data = dict(
dataset=dict(
type='CIFAR10Dataset',
root=Path(os.environ['DATA']),
transform_pipeline=[
dict(type='RandomCrop', size=IMG_SIZE, padding=4),
dict(type='RandomHorizontalFlip'),
dict(type='ToTensor'),
dict(type='Normalize',
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]),
]
),
dataloader=dict(
batch_size=BATCH_SIZE,
pin_memory=True,
num_workers=0,
shuffle=True
)
)
test_data = dict(
dataset=dict(
type='CIFAR10Dataset',
root=Path(os.environ['DATA']),
train=False,
transform_pipeline=[
dict(type='Resize', size=IMG_SIZE),
dict(type='ToTensor'),
dict(type='Normalize',
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]
),
]
),
dataloader=dict(
batch_size=400,
pin_memory=True,
num_workers=0,
shuffle=True
)
)
optimizer = dict(
type='Adam',
lr=0.001,
weight_decay=0
)
loss = dict(
type='CrossEntropyLoss2p5D',
)
model = dict(
type='VisionTransformerFromConfig',
tensor_splitting_cfg=dict(
type='ViTInputSplitter2p5D',
),
embedding_cfg=dict(
type='ViTPatchEmbedding2p5D',
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=DIM,
),
token_fusion_cfg=dict(
type='ViTTokenFuser2p5D',
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=DIM,
drop_rate=0.1
),
norm_cfg=dict(
type='LayerNorm2p5D',
normalized_shape=DIM,
eps=1e-6,
),
block_cfg=dict(
type='ViTBlock',
attention_cfg=dict(
type='ViTSelfAttention2p5D',
hidden_size=DIM,
num_attention_heads=NUM_ATTENTION_HEADS,
attention_dropout_prob=0.,
hidden_dropout_prob=0.1,
),
droppath_cfg=dict(
type='VanillaViTDropPath',
),
mlp_cfg=dict(
type='ViTMLP2p5D',
in_features=DIM,
dropout_prob=0.1,
mlp_ratio=1
),
norm_cfg=dict(
type='LayerNorm2p5D',
normalized_shape=DIM,
eps=1e-6,
),
),
head_cfg=dict(
type='ViTHead2p5D',
hidden_size=DIM,
num_classes=NUM_CLASSES,
),
embed_dim=DIM,
depth=DEPTH,
drop_path_rate=0.,
)
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=4, depth=1, mode='2.5d'),
)
num_epochs = 60
lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from pathlib import Path
from colossalai.context import ParallelMode
IMG_SIZE = 32
PATCH_SIZE = 4
EMBED_SIZE = 512
HIDDEN_SIZE = 512
NUM_HEADS = 8
NUM_CLASSES = 10
NUM_BLOCKS = 6
DROP_RATE = 0.1
BATCH_SIZE = 512
LEARNING_RATE = 0.001
DATASET_PATH = Path(os.environ['DATA'])
model = dict(
type='VisionTransformerFromConfig',
embedding_cfg=dict(
type='ViTPatchEmbedding3D',
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
in_chans=3,
embed_size=EMBED_SIZE,
drop_prob=DROP_RATE,
),
block_cfg=dict(
type='ViTBlock',
norm_cfg=dict(
type='LayerNorm3D',
normalized_shape=HIDDEN_SIZE,
eps=1e-6,
input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT,
),
attention_cfg=dict(
type='ViTSelfAttention3D',
hidden_size=HIDDEN_SIZE,
num_attention_heads=NUM_HEADS,
attention_probs_dropout_prob=0.,
hidden_dropout_prob=DROP_RATE,
),
droppath_cfg=dict(type='VanillaViTDropPath', ),
mlp_cfg=dict(
type='ViTMLP3D',
hidden_size=HIDDEN_SIZE,
mlp_ratio=1,
hidden_dropout_prob=DROP_RATE,
hidden_act='gelu',
),
),
norm_cfg=dict(type='LayerNorm3D',
normalized_shape=HIDDEN_SIZE,
eps=1e-6,
input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT),
head_cfg=dict(
type='ViTHead3D',
in_features=HIDDEN_SIZE,
num_classes=NUM_CLASSES,
),
embed_dim=HIDDEN_SIZE,
depth=NUM_BLOCKS,
drop_path_rate=0.,
)
loss = dict(type='CrossEntropyLoss3D',
input_parallel_mode=ParallelMode.PARALLEL_3D_OUTPUT,
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT,
reduction=True)
optimizer = dict(type='Adam', lr=LEARNING_RATE, weight_decay=0)
train_data = dict(dataset=dict(type='CIFAR10Dataset',
root=DATASET_PATH,
transform_pipeline=[
dict(type='RandomCrop',
size=IMG_SIZE,
padding=4),
dict(type='RandomHorizontalFlip'),
dict(type='ToTensor'),
dict(type='Normalize',
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]),
]),
dataloader=dict(batch_size=BATCH_SIZE,
pin_memory=True,
shuffle=True,
num_workers=8))
test_data = dict(dataset=dict(type='CIFAR10Dataset',
root=DATASET_PATH,
train=False,
transform_pipeline=[
dict(type='Resize', size=IMG_SIZE),
dict(type='ToTensor'),
dict(type='Normalize',
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]),
]),
dataloader=dict(batch_size=400,
pin_memory=True,
num_workers=8))
hooks = [
dict(type='LogMetricByEpochHook'),
dict(type='LogTimingByEpochHook'),
dict(type='LogMemoryByEpochHook'),
dict(
type='Accuracy3DHook',
input_parallel_mode=ParallelMode.PARALLEL_3D_OUTPUT,
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT,
),
dict(type='LossHook'),
dict(
type='LRSchedulerHook',
by_epoch=True,
lr_scheduler_cfg=dict(
type='LinearWarmupLR',
warmup_steps=5
)
),
]
parallel = dict(
data=1,
pipeline=1,
tensor=dict(mode='3d', size=8),
)
num_epochs = 60
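Across these configs the tensor-parallel `size` must fit the mode's process grid: `'2d'` arranges `size=4` GPUs as a 2x2 grid (hence `SUMMA_DIM = 2`), `'2.5d'` as a 2x2 grid replicated `depth` times, and `'3d'` arranges `size=8` as a 2x2x2 cube. A small sketch of that arithmetic (illustrative only, not Colossal-AI's actual validation code):

```python
# Check that a tensor-parallel world size fits the mode's process grid.
# Illustrative only; not Colossal-AI's actual validation logic.

def grid_dim(size, mode, depth=1):
    """Return the per-axis grid dimension, or raise if size doesn't fit."""
    if mode == '2d':
        d = round(size ** 0.5)                 # size = d * d
        assert d * d == size, 'size must be a perfect square'
    elif mode == '2.5d':
        d = round((size // depth) ** 0.5)      # size = d * d * depth
        assert d * d * depth == size, 'size must be d * d * depth'
    elif mode == '3d':
        d = round(size ** (1 / 3))             # size = d * d * d
        assert d ** 3 == size, 'size must be a perfect cube'
    else:
        raise ValueError(f'unknown mode: {mode}')
    return d
```

This is why the 2D and 2.5D configs above run on 4 GPUs while the 3D config asks for 8.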
import torch.nn as nn
IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 1000
NUM_ATTN_HEADS = 12
model = dict(
type='VisionTransformerFromConfig',
embedding_cfg=dict(
type='VanillaViTPatchEmbedding',
img_size=IMG_SIZE,
patch_size=16,
in_chans=3,
embed_dim=DIM
),
norm_cfg=dict(
type='LayerNorm',
eps=1e-6,
normalized_shape=DIM
),
block_cfg=dict(
type='ViTBlock',
checkpoint=True,
attention_cfg=dict(
type='VanillaViTAttention',
dim=DIM,
num_heads=NUM_ATTN_HEADS,
qkv_bias=True,
attn_drop=0.,
proj_drop=0.
),
droppath_cfg=dict(
type='VanillaViTDropPath',
),
mlp_cfg=dict(
type='VanillaViTMLP',
in_features=DIM,
hidden_features=DIM * 4,
act_layer=nn.GELU,
drop=0.
),
norm_cfg=dict(
type='LayerNorm',
normalized_shape=DIM
),
),
head_cfg=dict(
type='VanillaViTHead',
in_features=DIM,
intermediate_features=DIM * 2,
out_features=NUM_CLASSES
),
depth=12,
drop_path_rate=0.,
)
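With `DIM = 768`, `depth = 12`, 12 heads, and an MLP hidden size of `4 * DIM`, the config above is ViT-B/16-shaped, and its transformer blocks alone carry roughly `12 * DIM**2` parameters each. A back-of-the-envelope sketch (biases, LayerNorms, embeddings, and the head are deliberately ignored; the function names are mine, not the repo's):

```python
# Rough parameter count for the transformer blocks of a ViT-B/16-style
# model (DIM=768, depth=12, MLP hidden = 4*DIM). Biases, LayerNorms,
# embeddings, and the classification head are ignored in this estimate.

def block_params(dim, mlp_ratio=4):
    attn = 4 * dim * dim               # q, k, v, and output projections
    mlp = 2 * dim * (mlp_ratio * dim)  # two linear layers
    return attn + mlp

def vit_block_params(dim=768, depth=12, mlp_ratio=4):
    return depth * block_params(dim, mlp_ratio)
```

The estimate comes out near 85M, consistent with ViT-B's commonly quoted ~86M total once embeddings and the head are added back.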
#!/usr/bin/env sh
test_file=$1
python "$test_file" --local_rank "$SLURM_PROCID" --world_size "$SLURM_NPROCS" --host "$HOST" --port 29500
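The launcher above relies on SLURM's per-task environment: each rank receives its own `SLURM_PROCID` while `SLURM_NPROCS` and `HOST` are shared. A sketch of how the command line expands for rank 0 (the values and the `train.py` filename are examples, not from this job):

```shell
# Example expansion of the launcher for one rank under SLURM.
# Values and the script name are illustrative, not from this job.
SLURM_PROCID=0
SLURM_NPROCS=4
HOST=node001
cmd="python train.py --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500"
echo "$cmd"
```

Launching via `srun` (or `ibrun` on TACC systems) runs one copy per task, so the four-GPU logs below correspond to four such expansions with ranks 0 through 3.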
Tue Aug 31 23:19:11 CDT 2021
TACC: Starting up job 3475503
TACC: Starting parallel tasks...
warning: variables which starts with __, is a module or class declaration are omitted
process rank 0 is bound to device 0
distributed environment is initialzied
model is created
warning: variables which starts with __, is a module or class declaration are omitted
process rank 3 is bound to device 3
Files already downloaded and verified
Files already downloaded and verified
warning: variables which starts with __, is a module or class declaration are omitted
process rank 2 is bound to device 2
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
training and testing dataloaders are created
loss is created
optimizer is created
start training
warning: variables which starts with __, is a module or class declaration are omitted
process rank 1 is bound to device 1
Files already downloaded and verified
Files already downloaded and verified
epoch: 0, train loss: 1.9477875804414555
epoch: 0, eval loss: 1.8044581711292267, correct: 3584, total: 10000, acc = 0.35839998722076416
epoch: 1, train loss: 1.6170539077447386
epoch: 2, train loss: 1.4285206673096638
epoch: 2, eval loss: 1.3448221325874328, correct: 5167, total: 10000, acc = 0.5166999697685242
epoch: 3, train loss: 1.3227316007322194
epoch: 4, train loss: 1.2758926104526132
epoch: 4, eval loss: 1.2780393660068512, correct: 5460, total: 10000, acc = 0.5460000038146973
epoch: 5, train loss: 1.2221618829941263
epoch: 6, train loss: 1.1857815640313285
epoch: 6, eval loss: 1.1175921618938447, correct: 6023, total: 10000, acc = 0.6022999882698059
epoch: 7, train loss: 1.1659576710389585
epoch: 8, train loss: 1.1457150134505059
epoch: 8, eval loss: 1.0789835333824158, correct: 6113, total: 10000, acc = 0.611299991607666
epoch: 9, train loss: 1.1156543700062498
epoch: 10, train loss: 1.0950473242876482
epoch: 10, eval loss: 1.058586174249649, correct: 6170, total: 10000, acc = 0.6169999837875366
epoch: 11, train loss: 1.0976360866001673
epoch: 12, train loss: 1.07803391193857
epoch: 12, eval loss: 1.0039635241031646, correct: 6351, total: 10000, acc = 0.6351000070571899
epoch: 13, train loss: 1.0680764615535736
epoch: 14, train loss: 1.0364759442757587
epoch: 14, eval loss: 0.9748250603675842, correct: 6486, total: 10000, acc = 0.6485999822616577
epoch: 15, train loss: 1.023898609438721
epoch: 16, train loss: 0.9982165353638786
epoch: 16, eval loss: 0.9612966269254685, correct: 6591, total: 10000, acc = 0.6590999960899353
epoch: 17, train loss: 0.9698412771127662
epoch: 18, train loss: 0.9523191050607331
epoch: 18, eval loss: 0.8974281877279282, correct: 6810, total: 10000, acc = 0.6809999942779541
epoch: 19, train loss: 0.9171817661548147
epoch: 20, train loss: 0.8905259948603961
epoch: 20, eval loss: 0.8580602705478668, correct: 6965, total: 10000, acc = 0.6965000033378601
epoch: 21, train loss: 0.86673782917918
epoch: 22, train loss: 0.8339344001546198
epoch: 22, eval loss: 0.8263293951749802, correct: 7107, total: 10000, acc = 0.7106999754905701
epoch: 23, train loss: 0.8074834510988119
epoch: 24, train loss: 0.7840324482139276
epoch: 24, eval loss: 0.752952727675438, correct: 7317, total: 10000, acc = 0.7317000031471252
epoch: 25, train loss: 0.7541018596717289
epoch: 26, train loss: 0.7357191905683401
epoch: 26, eval loss: 0.7338999301195145, correct: 7410, total: 10000, acc = 0.7409999966621399
epoch: 27, train loss: 0.7107210451242875
epoch: 28, train loss: 0.6785972909051545
epoch: 28, eval loss: 0.7020785599946976, correct: 7523, total: 10000, acc = 0.752299964427948
epoch: 29, train loss: 0.660102152094549
epoch: 30, train loss: 0.6498027924372225
epoch: 30, eval loss: 0.6610008627176285, correct: 7661, total: 10000, acc = 0.7660999894142151
epoch: 31, train loss: 0.6297167344969146
epoch: 32, train loss: 0.6150159224563715
epoch: 32, eval loss: 0.6350889533758164, correct: 7802, total: 10000, acc = 0.7802000045776367
epoch: 33, train loss: 0.5912032842027898
epoch: 34, train loss: 0.5761601137263435
epoch: 34, eval loss: 0.6296706795692444, correct: 7786, total: 10000, acc = 0.7785999774932861
epoch: 35, train loss: 0.5586571322411907
epoch: 36, train loss: 0.5488096165413759
epoch: 36, eval loss: 0.6041992783546448, correct: 7913, total: 10000, acc = 0.7912999987602234
epoch: 37, train loss: 0.5273334958723613
epoch: 38, train loss: 0.5074144468015555
epoch: 38, eval loss: 0.5868680268526077, correct: 7984, total: 10000, acc = 0.7983999848365784
epoch: 39, train loss: 0.4930413775906271
epoch: 40, train loss: 0.47384805977344513
epoch: 40, eval loss: 0.6013937592506409, correct: 7945, total: 10000, acc = 0.7944999933242798
epoch: 41, train loss: 0.4618621742238804
epoch: 42, train loss: 0.4452754973757024
epoch: 42, eval loss: 0.5606920897960663, correct: 8093, total: 10000, acc = 0.8093000054359436
epoch: 43, train loss: 0.4361336164328517
epoch: 44, train loss: 0.4188923318775333
epoch: 44, eval loss: 0.5567828729748726, correct: 8042, total: 10000, acc = 0.8041999936103821
epoch: 45, train loss: 0.4047189655960823
epoch: 46, train loss: 0.3873833852763079
epoch: 46, eval loss: 0.5404785141348839, correct: 8166, total: 10000, acc = 0.81659996509552
epoch: 47, train loss: 0.3707445412874222
epoch: 48, train loss: 0.3631058514726405
epoch: 48, eval loss: 0.5541519388556481, correct: 8201, total: 10000, acc = 0.820099949836731
epoch: 49, train loss: 0.34395075604623676
epoch: 50, train loss: 0.3290589987015238
epoch: 50, eval loss: 0.5442438080906868, correct: 8169, total: 10000, acc = 0.8168999552726746
epoch: 51, train loss: 0.3188562990755451
epoch: 52, train loss: 0.2986554713273535
epoch: 52, eval loss: 0.5515974283218383, correct: 8203, total: 10000, acc = 0.8202999830245972
epoch: 53, train loss: 0.29044121671087886
epoch: 54, train loss: 0.27310980613134345
epoch: 54, eval loss: 0.5587902516126633, correct: 8195, total: 10000, acc = 0.8194999694824219
epoch: 55, train loss: 0.2637303553673686
epoch: 56, train loss: 0.2521531299060705
epoch: 56, eval loss: 0.5885633528232574, correct: 8202, total: 10000, acc = 0.8201999664306641
epoch: 57, train loss: 0.23304983274060853
epoch: 58, train loss: 0.22784664101746618
epoch: 58, eval loss: 0.5882876932621002, correct: 8245, total: 10000, acc = 0.8244999647140503
epoch: 59, train loss: 0.21604956868959932
epoch: 60, train loss: 0.20325114882113982
epoch: 60, eval loss: 0.5910753712058068, correct: 8248, total: 10000, acc = 0.8247999548912048
epoch: 61, train loss: 0.19390226033877353
epoch: 62, train loss: 0.18323212360240976
epoch: 62, eval loss: 0.6264512360095977, correct: 8272, total: 10000, acc = 0.8271999955177307
epoch: 63, train loss: 0.1680474430322647
epoch: 64, train loss: 0.16121925512442783
epoch: 64, eval loss: 0.640467157959938, correct: 8283, total: 10000, acc = 0.8282999992370605
epoch: 65, train loss: 0.14981685054241395
epoch: 66, train loss: 0.14731310475237516
epoch: 66, eval loss: 0.6354441046714783, correct: 8303, total: 10000, acc = 0.830299973487854
epoch: 67, train loss: 0.13300996874364054
epoch: 68, train loss: 0.12739452506814683
epoch: 68, eval loss: 0.6673313498497009, correct: 8282, total: 10000, acc = 0.8281999826431274
epoch: 69, train loss: 0.11627298555507952
epoch: 70, train loss: 0.10940710728874012
epoch: 70, eval loss: 0.692647397518158, correct: 8302, total: 10000, acc = 0.8301999568939209
epoch: 71, train loss: 0.10183572788171623
epoch: 72, train loss: 0.09634554986746943
epoch: 72, eval loss: 0.695426219701767, correct: 8299, total: 10000, acc = 0.8298999667167664
epoch: 73, train loss: 0.09228058896806775
epoch: 74, train loss: 0.08581420976896675
epoch: 74, eval loss: 0.694861987233162, correct: 8340, total: 10000, acc = 0.8339999914169312
epoch: 75, train loss: 0.07914869715364611
epoch: 76, train loss: 0.0742536057547039
epoch: 76, eval loss: 0.7130348771810532, correct: 8347, total: 10000, acc = 0.8346999883651733
epoch: 77, train loss: 0.06935026941402835
epoch: 78, train loss: 0.0665280031306403
epoch: 78, eval loss: 0.7465721786022186, correct: 8341, total: 10000, acc = 0.8341000080108643
epoch: 79, train loss: 0.05928862589050313
epoch: 80, train loss: 0.05455683164146482
epoch: 80, eval loss: 0.776301947236061, correct: 8314, total: 10000, acc = 0.8313999772071838
epoch: 81, train loss: 0.05638634926658504
epoch: 82, train loss: 0.05360411343221762
epoch: 82, eval loss: 0.7883096963167191, correct: 8332, total: 10000, acc = 0.8331999778747559
epoch: 83, train loss: 0.04867944630737207
epoch: 84, train loss: 0.0474467751931171
epoch: 84, eval loss: 0.7960963994264603, correct: 8340, total: 10000, acc = 0.8339999914169312
epoch: 85, train loss: 0.044186076149344444
epoch: 86, train loss: 0.041499203527156185
epoch: 86, eval loss: 0.7910951197147369, correct: 8364, total: 10000, acc = 0.8363999724388123
epoch: 87, train loss: 0.03865447499770291
epoch: 88, train loss: 0.03929391104195799
epoch: 88, eval loss: 0.8027831196784974, correct: 8371, total: 10000, acc = 0.8370999693870544
epoch: 89, train loss: 0.03619915826664287
epoch: 90, train loss: 0.034103386617284646
epoch: 90, eval loss: 0.8021850943565368, correct: 8357, total: 10000, acc = 0.8356999754905701
epoch: 91, train loss: 0.037686211741244306
epoch: 92, train loss: 0.033469487000636906
epoch: 92, eval loss: 0.8143919885158539, correct: 8359, total: 10000, acc = 0.8359000086784363
epoch: 93, train loss: 0.03238337778733397
epoch: 94, train loss: 0.03141044888987529
epoch: 94, eval loss: 0.8093269526958465, correct: 8385, total: 10000, acc = 0.8384999632835388
epoch: 95, train loss: 0.031111840225223984
epoch: 96, train loss: 0.032168653251945366
epoch: 96, eval loss: 0.8102991491556167, correct: 8379, total: 10000, acc = 0.8378999829292297
epoch: 97, train loss: 0.03043455306478605
epoch: 98, train loss: 0.03200419174925405
epoch: 98, eval loss: 0.8105081558227539, correct: 8373, total: 10000, acc = 0.8373000025749207
epoch: 99, train loss: 0.031662615329711416
finish training
TACC: Shutdown complete. Exiting.
Tue Aug 31 12:28:41 CDT 2021
TACC: Starting up job 3472937
TACC: Starting parallel tasks...
warning: variables which starts with __, is a module or class declaration are omitted
process rank 0 is bound to device 0
distributed environment is initialzied
model is created
warning: variables which starts with __, is a module or class declaration are omitted
process rank 3 is bound to device 3
Files already downloaded and verified
Files already downloaded and verified
warning: variables which starts with __, is a module or class declaration are omitted
process rank 2 is bound to device 2
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
training and testing dataloaders are created
loss is created
warning: variables which starts with __, is a module or class declaration are omitted
process rank 1 is bound to device 1
Files already downloaded and verified
Files already downloaded and verified
optimizer is created
start training
epoch: 0, train loss: 2.0869219929901597
epoch: 0, eval loss: 1.9415993988513947, correct: 2875, total: 10000, acc = 0.2874999940395355
epoch: 1, train loss: 1.832751432644952
epoch: 2, train loss: 1.6953342194409715
epoch: 2, eval loss: 1.6502822101116181, correct: 4026, total: 10000, acc = 0.4025999903678894
epoch: 3, train loss: 1.583214813900977
epoch: 4, train loss: 1.4921425851349979
epoch: 4, eval loss: 1.4688542783260345, correct: 4773, total: 10000, acc = 0.4772999882698059
epoch: 5, train loss: 1.3872919402171655
epoch: 6, train loss: 1.3123903028743784
epoch: 6, eval loss: 1.328972715139389, correct: 5275, total: 10000, acc = 0.5274999737739563
epoch: 7, train loss: 1.2454541012183906
epoch: 8, train loss: 1.1953427422906935
epoch: 8, eval loss: 1.2376527905464172, correct: 5579, total: 10000, acc = 0.5579000115394592
epoch: 9, train loss: 1.1491977308214325
epoch: 10, train loss: 1.1148795012346249
epoch: 10, eval loss: 1.1297606527805328, correct: 5975, total: 10000, acc = 0.5974999666213989
epoch: 11, train loss: 1.076469630310216
epoch: 12, train loss: 1.0476364874348199
epoch: 12, eval loss: 1.029269078373909, correct: 6333, total: 10000, acc = 0.6333000063896179
epoch: 13, train loss: 1.0117879393174476
epoch: 14, train loss: 0.9859390357106003
epoch: 14, eval loss: 0.97474505007267, correct: 6494, total: 10000, acc = 0.649399995803833
epoch: 15, train loss: 0.9595183336857668
epoch: 16, train loss: 0.9384779051407096
epoch: 16, eval loss: 0.9172703564167023, correct: 6716, total: 10000, acc = 0.6715999841690063
epoch: 17, train loss: 0.9127772370564569
epoch: 18, train loss: 0.889132705545917
epoch: 18, eval loss: 0.8939311623573303, correct: 6809, total: 10000, acc = 0.680899977684021
epoch: 19, train loss: 0.8719241373317758
epoch: 20, train loss: 0.8456196920158937
epoch: 20, eval loss: 0.8584266930818558, correct: 6944, total: 10000, acc = 0.6943999528884888
epoch: 21, train loss: 0.8258345379042871
epoch: 22, train loss: 0.8185748826597155
epoch: 22, eval loss: 0.8427778095006943, correct: 7020, total: 10000, acc = 0.7019999623298645
epoch: 23, train loss: 0.794703829534275
epoch: 24, train loss: 0.777785701235545
epoch: 24, eval loss: 0.801164984703064, correct: 7182, total: 10000, acc = 0.7181999683380127
epoch: 25, train loss: 0.760752295095896
epoch: 26, train loss: 0.7453707229230822
epoch: 26, eval loss: 0.7841533124446869, correct: 7209, total: 10000, acc = 0.7208999991416931
epoch: 27, train loss: 0.7267675215436011
epoch: 28, train loss: 0.7131575210807249
epoch: 28, eval loss: 0.7708685129880906, correct: 7254, total: 10000, acc = 0.7253999710083008
epoch: 29, train loss: 0.7007347437524304
epoch: 30, train loss: 0.6834727574869529
epoch: 30, eval loss: 0.7591335833072662, correct: 7356, total: 10000, acc = 0.7355999946594238
epoch: 31, train loss: 0.6712760894568925
epoch: 32, train loss: 0.655675129177644
epoch: 32, eval loss: 0.7655339151620865, correct: 7314, total: 10000, acc = 0.7313999533653259
epoch: 33, train loss: 0.6421149262447947
epoch: 34, train loss: 0.6301654601834484
epoch: 34, eval loss: 0.7450480967760086, correct: 7350, total: 10000, acc = 0.73499995470047
epoch: 35, train loss: 0.6189313580080406
epoch: 36, train loss: 0.6047559282214371
epoch: 36, eval loss: 0.7468931972980499, correct: 7392, total: 10000, acc = 0.7391999959945679
epoch: 37, train loss: 0.5878085592358383
epoch: 38, train loss: 0.5731440121980057
epoch: 38, eval loss: 0.7349929332733154, correct: 7434, total: 10000, acc = 0.743399977684021
epoch: 39, train loss: 0.5633921856732712
epoch: 40, train loss: 0.5499549056451345
epoch: 40, eval loss: 0.7258913427591324, correct: 7483, total: 10000, acc = 0.7482999563217163
epoch: 41, train loss: 0.5403583102005044
epoch: 42, train loss: 0.5286270272485989
epoch: 42, eval loss: 0.7170430123806, correct: 7528, total: 10000, acc = 0.7527999877929688
epoch: 43, train loss: 0.5166667195939526
epoch: 44, train loss: 0.5098928068716502
epoch: 44, eval loss: 0.7244295090436935, correct: 7531, total: 10000, acc = 0.7530999779701233
epoch: 45, train loss: 0.4917362458312634
epoch: 46, train loss: 0.48251094676784634
epoch: 46, eval loss: 0.728115001320839, correct: 7557, total: 10000, acc = 0.7556999921798706
epoch: 47, train loss: 0.47845434067175563
epoch: 48, train loss: 0.4637242813700253
epoch: 48, eval loss: 0.7259155690670014, correct: 7559, total: 10000, acc = 0.755899965763092
epoch: 49, train loss: 0.4557308668328315
epoch: 50, train loss: 0.4414065560114752
epoch: 50, eval loss: 0.7056828439235687, correct: 7648, total: 10000, acc = 0.764799952507019
epoch: 51, train loss: 0.43054875792916286
epoch: 52, train loss: 0.4196087404624703
epoch: 52, eval loss: 0.7131796926259995, correct: 7670, total: 10000, acc = 0.7669999599456787
epoch: 53, train loss: 0.41613124971537246
epoch: 54, train loss: 0.4016842494920357
epoch: 54, eval loss: 0.7215427845716477, correct: 7641, total: 10000, acc = 0.7640999555587769
epoch: 55, train loss: 0.39098499054761277
epoch: 56, train loss: 0.3805114098430909
epoch: 56, eval loss: 0.7281092345714569, correct: 7672, total: 10000, acc = 0.7671999931335449
epoch: 57, train loss: 0.3724562412070245
epoch: 58, train loss: 0.37037558162335266
epoch: 58, eval loss: 0.7282122701406479, correct: 7707, total: 10000, acc = 0.7706999778747559
epoch: 59, train loss: 0.3584493641386327
epoch: 60, train loss: 0.35091825858833864
epoch: 60, eval loss: 0.7441833585500717, correct: 7653, total: 10000, acc = 0.7652999758720398
epoch: 61, train loss: 0.3469349926279992
epoch: 62, train loss: 0.33631756533052504
epoch: 62, eval loss: 0.7398703306913376, correct: 7679, total: 10000, acc = 0.7678999900817871
epoch: 63, train loss: 0.3287597510618033
epoch: 64, train loss: 0.3201854192104536
epoch: 64, eval loss: 0.7452850311994552, correct: 7676, total: 10000, acc = 0.7675999999046326
epoch: 65, train loss: 0.3196122018025093
epoch: 66, train loss: 0.30826768893556494
epoch: 66, eval loss: 0.7634350836277009, correct: 7637, total: 10000, acc = 0.763700008392334
epoch: 67, train loss: 0.30273781855081777
epoch: 68, train loss: 0.29493943732423883
epoch: 68, eval loss: 0.7755635917186737, correct: 7679, total: 10000, acc = 0.7678999900817871
epoch: 69, train loss: 0.2938486310010104
epoch: 70, train loss: 0.2860709805156767
epoch: 70, eval loss: 0.7754869312047958, correct: 7690, total: 10000, acc = 0.7689999938011169
epoch: 71, train loss: 0.27915918873142953
epoch: 72, train loss: 0.2728954671891694
epoch: 72, eval loss: 0.78065524995327, correct: 7693, total: 10000, acc = 0.7692999839782715
epoch: 73, train loss: 0.2666821759386161
epoch: 74, train loss: 0.2651688180018946
epoch: 74, eval loss: 0.7740301787853241, correct: 7739, total: 10000, acc = 0.7738999724388123
epoch: 75, train loss: 0.2613655937086676
epoch: 76, train loss: 0.2508497520820382
epoch: 76, eval loss: 0.7777235865592956, correct: 7777, total: 10000, acc = 0.7777000069618225
epoch: 77, train loss: 0.2466566213934692
epoch: 78, train loss: 0.24763428181717076
epoch: 78, eval loss: 0.7764069586992264, correct: 7765, total: 10000, acc = 0.7764999866485596
epoch: 79, train loss: 0.24774935457509817
epoch: 80, train loss: 0.23876388401714796
epoch: 80, eval loss: 0.7927991092205048, correct: 7740, total: 10000, acc = 0.7739999890327454
epoch: 81, train loss: 0.23702984618157455
epoch: 82, train loss: 0.2349560634069836
epoch: 82, eval loss: 0.794422161579132, correct: 7734, total: 10000, acc = 0.7734000086784363
epoch: 83, train loss: 0.2302393423220546
epoch: 84, train loss: 0.22696133203727684
epoch: 84, eval loss: 0.7939992696046829, correct: 7768, total: 10000, acc = 0.7767999768257141
epoch: 85, train loss: 0.22804359692273682
epoch: 86, train loss: 0.21921918088013365
epoch: 86, eval loss: 0.792869821190834, correct: 7768, total: 10000, acc = 0.7767999768257141
epoch: 87, train loss: 0.22169384437123524
epoch: 88, train loss: 0.21990801271089574
epoch: 88, eval loss: 0.7874982714653015, correct: 7761, total: 10000, acc = 0.7760999798774719
epoch: 89, train loss: 0.2174029858763685
epoch: 90, train loss: 0.21091166722405816
epoch: 90, eval loss: 0.7935442298650741, correct: 7765, total: 10000, acc = 0.7764999866485596
epoch: 91, train loss: 0.214123952788176
epoch: 92, train loss: 0.2140680920217455
epoch: 92, eval loss: 0.7855452030897141, correct: 7787, total: 10000, acc = 0.7786999940872192
epoch: 93, train loss: 0.2146580269963471
epoch: 94, train loss: 0.2091474826495672
epoch: 94, eval loss: 0.7892280638217926, correct: 7783, total: 10000, acc = 0.7782999873161316
epoch: 95, train loss: 0.21155885015566325
epoch: 96, train loss: 0.21236088549353413
epoch: 96, eval loss: 0.7883010923862457, correct: 7785, total: 10000, acc = 0.778499960899353
epoch: 97, train loss: 0.20943679852583974
epoch: 98, train loss: 0.20945941495526696
epoch: 98, eval loss: 0.7873147040605545, correct: 7783, total: 10000, acc = 0.7782999873161316
epoch: 99, train loss: 0.2085563373012641
finish training
TACC: Shutdown complete. Exiting.
Wed Sep 1 01:07:01 CDT 2021
TACC: Starting up job 3476018
TACC: Starting parallel tasks...
warning: variables which starts with __, is a module or class declaration are omitted
process rank 0 is bound to device 0
distributed environment is initialized
model is created
Files already downloaded and verified
Files already downloaded and verified
training and testing dataloaders are created
loss is created
optimizer is created
start training
epoch: 0, train loss: 1.9497510997616514
epoch: 0, eval loss: 1.754234939813614, correct: 3521, total: 10000, acc = 0.3520999848842621
epoch: 1, train loss: 1.6049139609142227
epoch: 2, train loss: 1.3857794501343552
epoch: 2, eval loss: 1.2831632316112518, correct: 5410, total: 10000, acc = 0.5410000085830688
epoch: 3, train loss: 1.3016913873808724
epoch: 4, train loss: 1.2616293649284207
epoch: 4, eval loss: 1.2658930838108062, correct: 5409, total: 10000, acc = 0.5408999919891357
epoch: 5, train loss: 1.2320433721250417
epoch: 6, train loss: 1.181612290898148
epoch: 6, eval loss: 1.1402096092700957, correct: 5881, total: 10000, acc = 0.5880999565124512
epoch: 7, train loss: 1.1643818397911228
epoch: 8, train loss: 1.128499301112428
epoch: 8, eval loss: 1.0965303361415863, correct: 6053, total: 10000, acc = 0.6053000092506409
epoch: 9, train loss: 1.114193707704544
epoch: 10, train loss: 1.0830892950904614
epoch: 10, eval loss: 1.0390974164009095, correct: 6258, total: 10000, acc = 0.6258000135421753
epoch: 11, train loss: 1.0508871960396668
epoch: 12, train loss: 1.0322130365031106
epoch: 12, eval loss: 0.9689173698425293, correct: 6482, total: 10000, acc = 0.6481999754905701
epoch: 13, train loss: 1.0006194637746226
epoch: 14, train loss: 0.9652800906677635
epoch: 14, eval loss: 0.9150958389043808, correct: 6713, total: 10000, acc = 0.6712999939918518
epoch: 15, train loss: 0.9430981692002744
epoch: 16, train loss: 0.9156872307767674
epoch: 16, eval loss: 0.8703682094812393, correct: 6913, total: 10000, acc = 0.6912999749183655
epoch: 17, train loss: 0.8822251515729087
epoch: 18, train loss: 0.8485424190151448
epoch: 18, eval loss: 0.8234190821647644, correct: 7120, total: 10000, acc = 0.7119999527931213
epoch: 19, train loss: 0.8285953049757042
epoch: 20, train loss: 0.8009484337300671
epoch: 20, eval loss: 0.7808267176151276, correct: 7228, total: 10000, acc = 0.7227999567985535
epoch: 21, train loss: 0.7774611741912608
epoch: 22, train loss: 0.7435575358721674
epoch: 22, eval loss: 0.7523189872503281, correct: 7367, total: 10000, acc = 0.7366999983787537
epoch: 23, train loss: 0.7315681789602552
epoch: 24, train loss: 0.70117900627
epoch: 24, eval loss: 0.6928718358278274, correct: 7580, total: 10000, acc = 0.7579999566078186
epoch: 25, train loss: 0.677533069435431
epoch: 26, train loss: 0.6627033298112908
epoch: 26, eval loss: 0.6921748876571655, correct: 7586, total: 10000, acc = 0.7585999965667725
epoch: 27, train loss: 0.6410714266251545
epoch: 28, train loss: 0.6192339707394036
epoch: 28, eval loss: 0.6416671514511109, correct: 7719, total: 10000, acc = 0.7718999981880188
epoch: 29, train loss: 0.6093639281331277
epoch: 30, train loss: 0.582532714520182
epoch: 30, eval loss: 0.6166591048240662, correct: 7809, total: 10000, acc = 0.7809000015258789
epoch: 31, train loss: 0.572193189847226
epoch: 32, train loss: 0.5541256200902316
epoch: 32, eval loss: 0.5951347410678863, correct: 7922, total: 10000, acc = 0.792199969291687
epoch: 33, train loss: 0.5345369838938421
epoch: 34, train loss: 0.5273816007740644
epoch: 34, eval loss: 0.5837202191352844, correct: 7972, total: 10000, acc = 0.7971999645233154
epoch: 35, train loss: 0.5059237045292951
epoch: 36, train loss: 0.48622317095192114
epoch: 36, eval loss: 0.5698897138237953, correct: 8024, total: 10000, acc = 0.8023999929428101
epoch: 37, train loss: 0.47362951143663756
epoch: 38, train loss: 0.46030426907296085
epoch: 38, eval loss: 0.5610475659370422, correct: 8049, total: 10000, acc = 0.8048999905586243
epoch: 39, train loss: 0.44165324921510657
epoch: 40, train loss: 0.4327346086502075
epoch: 40, eval loss: 0.5642214670777321, correct: 8095, total: 10000, acc = 0.809499979019165
epoch: 41, train loss: 0.41423581935921494
epoch: 42, train loss: 0.40917488780556893
epoch: 42, eval loss: 0.5602998435497284, correct: 8131, total: 10000, acc = 0.8130999803543091
epoch: 43, train loss: 0.39171184477757437
epoch: 44, train loss: 0.3744060835059808
epoch: 44, eval loss: 0.5633655220270157, correct: 8134, total: 10000, acc = 0.8133999705314636
epoch: 45, train loss: 0.36267226934432983
epoch: 46, train loss: 0.3420030690577565
epoch: 46, eval loss: 0.5533872425556183, correct: 8157, total: 10000, acc = 0.8156999945640564
epoch: 47, train loss: 0.3287143409252167
epoch: 48, train loss: 0.316296321396925
epoch: 48, eval loss: 0.5576229721307755, correct: 8209, total: 10000, acc = 0.8208999633789062
epoch: 49, train loss: 0.3068045072105466
epoch: 50, train loss: 0.2929732614025778
epoch: 50, eval loss: 0.5654072970151901, correct: 8227, total: 10000, acc = 0.8226999640464783
epoch: 51, train loss: 0.2795026940958841
epoch: 52, train loss: 0.26673941375041493
epoch: 52, eval loss: 0.5736668109893799, correct: 8227, total: 10000, acc = 0.8226999640464783
epoch: 53, train loss: 0.2506744866164363
epoch: 54, train loss: 0.24351145980917677
epoch: 54, eval loss: 0.5846156671643257, correct: 8204, total: 10000, acc = 0.8203999996185303
epoch: 55, train loss: 0.2253616195248098
epoch: 56, train loss: 0.2177750574690955
epoch: 56, eval loss: 0.5943332687020302, correct: 8246, total: 10000, acc = 0.8245999813079834
epoch: 57, train loss: 0.20670234989755007
epoch: 58, train loss: 0.1973607996288611
epoch: 58, eval loss: 0.6195310011506081, correct: 8245, total: 10000, acc = 0.8244999647140503
epoch: 59, train loss: 0.19024320448539694
epoch: 60, train loss: 0.17597664877468225
epoch: 60, eval loss: 0.6139472931623459, correct: 8294, total: 10000, acc = 0.8294000029563904
epoch: 61, train loss: 0.1674150490791214
epoch: 62, train loss: 0.15718420511301684
epoch: 62, eval loss: 0.6285309329628944, correct: 8261, total: 10000, acc = 0.8260999917984009
epoch: 63, train loss: 0.1480691913439303
epoch: 64, train loss: 0.1384550367234921
epoch: 64, eval loss: 0.6587671056389809, correct: 8263, total: 10000, acc = 0.8262999653816223
epoch: 65, train loss: 0.13241269834795777
epoch: 66, train loss: 0.12871786830376605
epoch: 66, eval loss: 0.6718123883008957, correct: 8303, total: 10000, acc = 0.830299973487854
epoch: 67, train loss: 0.11577517866176001
epoch: 68, train loss: 0.11130036151378739
epoch: 68, eval loss: 0.6887702852487564, correct: 8332, total: 10000, acc = 0.8331999778747559
epoch: 69, train loss: 0.09883711646710124
epoch: 70, train loss: 0.09635799735480426
epoch: 70, eval loss: 0.7159708231687546, correct: 8307, total: 10000, acc = 0.8306999802589417
epoch: 71, train loss: 0.09449125119313902
epoch: 72, train loss: 0.08857650914210446
epoch: 72, eval loss: 0.7160102307796479, correct: 8351, total: 10000, acc = 0.835099995136261
epoch: 73, train loss: 0.08085554241373831
epoch: 74, train loss: 0.07873564483407809
epoch: 74, eval loss: 0.7119918942451477, correct: 8393, total: 10000, acc = 0.8392999768257141
epoch: 75, train loss: 0.07206312137446841
epoch: 76, train loss: 0.06772394200824962
epoch: 76, eval loss: 0.7328802436590195, correct: 8351, total: 10000, acc = 0.835099995136261
epoch: 77, train loss: 0.061777200397788265
epoch: 78, train loss: 0.05721901174710722
epoch: 78, eval loss: 0.7407010316848754, correct: 8385, total: 10000, acc = 0.8384999632835388
epoch: 79, train loss: 0.056560877406475495
epoch: 80, train loss: 0.0528045150318316
epoch: 80, eval loss: 0.7767532706260681, correct: 8354, total: 10000, acc = 0.8353999853134155
epoch: 81, train loss: 0.050682742870887934
epoch: 82, train loss: 0.04895328068915678
epoch: 82, eval loss: 0.7942879348993301, correct: 8368, total: 10000, acc = 0.8367999792098999
epoch: 83, train loss: 0.04686643050185272
epoch: 84, train loss: 0.04325723648071289
epoch: 84, eval loss: 0.7906839996576309, correct: 8356, total: 10000, acc = 0.835599958896637
epoch: 85, train loss: 0.040166335769605876
epoch: 86, train loss: 0.039296497894945194
epoch: 86, eval loss: 0.8033982694149018, correct: 8376, total: 10000, acc = 0.8375999927520752
epoch: 87, train loss: 0.038185219698566565
epoch: 88, train loss: 0.03735689769441984
epoch: 88, eval loss: 0.8039661139249802, correct: 8377, total: 10000, acc = 0.8376999497413635
epoch: 89, train loss: 0.03383794939145446
epoch: 90, train loss: 0.03318257091034736
epoch: 90, eval loss: 0.8097118645906448, correct: 8389, total: 10000, acc = 0.8388999700546265
epoch: 91, train loss: 0.03290939923109753
epoch: 92, train loss: 0.030776230903456405
epoch: 92, eval loss: 0.8237936168909072, correct: 8401, total: 10000, acc = 0.8400999903678894
epoch: 93, train loss: 0.033349379108344415
epoch: 94, train loss: 0.031906195783189366
epoch: 94, eval loss: 0.8250258564949036, correct: 8401, total: 10000, acc = 0.8400999903678894
epoch: 95, train loss: 0.03031293043334569
epoch: 96, train loss: 0.029958056238460904
epoch: 96, eval loss: 0.8200247555971145, correct: 8402, total: 10000, acc = 0.8402000069618225
epoch: 97, train loss: 0.029532150564981357
epoch: 98, train loss: 0.029668816346295025
epoch: 98, eval loss: 0.821219089627266, correct: 8399, total: 10000, acc = 0.8398999571800232
epoch: 99, train loss: 0.02980129667842875
finish training
TACC: Shutdown complete. Exiting.
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from pathlib import Path

import pytest
import torch.autograd

import colossalai
from colossalai.builder import build_lr_scheduler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn.layer._parallel_utilities import _gather

CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')


def eval(engine, test_dataloader):
    engine.eval()
    accumulated_loss = 0
    correct_sum = 0
    total_sum = 0
    num_steps = len(test_dataloader)
    data_iter = iter(test_dataloader)

    for i in range(num_steps):
        output, label, loss = engine.step(data_iter)
        accumulated_loss += loss.detach().cpu().numpy()

        # gather the partitioned outputs across the 2D tensor-parallel groups
        output = _gather(
            output[0],
            ParallelMode.PARALLEL_2D_ROW,
            1
        )
        output = _gather(
            output,
            ParallelMode.PARALLEL_2D_COL,
            0,
        )
        output = torch.argmax(output, dim=-1)
        correct = torch.sum(label[0] == output)
        correct_sum += correct
        total_sum += label[0].size(0)

    avg_loss = accumulated_loss / num_steps
    return correct_sum, total_sum, avg_loss


def train(engine, train_dataloader, lr_scheduler):
    engine.train()
    accumulated_loss = 0
    num_steps = len(train_dataloader)
    data_iter = iter(train_dataloader)

    for i in range(num_steps):
        output, label, loss = engine.step(data_iter)
        accumulated_loss += loss.detach().cpu().numpy()
    avg_loss = accumulated_loss / num_steps
    lr_scheduler.step()
    return avg_loss


@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2d_parallel_vision_transformer():
    # init dist
    engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
    lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
    logger = get_global_dist_logger()

    logger.info('start training')
    for epoch in range(gpc.config.num_epochs):
        train_loss = train(engine, train_dataloader, lr_scheduler)
        logger.info(f'epoch {epoch} - train loss: {train_loss}')

        if epoch % 2 == 0:
            correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
            logger.info(
                f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
                f'correct: {correct_sum}, acc: {correct_sum / total_sum}')


if __name__ == '__main__':
    test_2d_parallel_vision_transformer()
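The eval logging above emits the fixed-format lines seen throughout this file (`epoch: N, eval loss: L, correct: C, total: T, acc = A`). A minimal sketch for extracting accuracy curves from such logs — `parse_eval_line` is a hypothetical helper, not part of the repository:

```python
import re

# Matches eval lines like:
# "epoch: 52, eval loss: 0.7131..., correct: 7670, total: 10000, acc = 0.7669..."
EVAL_RE = re.compile(
    r"epoch: (\d+), eval loss: ([\d.]+), correct: (\d+), total: (\d+), acc = ([\d.]+)"
)

def parse_eval_line(line):
    """Return a dict of eval metrics, or None for non-eval lines."""
    m = EVAL_RE.match(line)
    if m is None:
        return None
    epoch, loss, correct, total, acc = m.groups()
    return {
        "epoch": int(epoch),
        "eval_loss": float(loss),
        "correct": int(correct),
        "total": int(total),
        "acc": float(acc),
    }

sample = ("epoch: 52, eval loss: 0.7131796926259995, "
          "correct: 7670, total: 10000, acc = 0.7669999599456787")
print(parse_eval_line(sample)["correct"])  # 7670
```

Train-loss lines use a different prefix and are skipped (the regex returns None for them), so the helper can be mapped over a whole log file to recover the eval curve only.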
TACC: Starting up job 3498212
TACC: Starting parallel tasks...
warning: variables which starts with __, is a module or class declaration are omitted
process rank 0 is bound to device 0
distributed environment is initialized
model is created
Files already downloaded and verified
Files already downloaded and verified
training and testing dataloaders are created
loss is created
optimizer is created
start training
epoch: 0, train loss: 1.9590576728995965
epoch: 1, train loss: 1.6275222167676808
epoch: 1, eval loss: 1.5277319371700286, correct: 4435, total: 10000, acc = 0.44349998235702515
epoch: 2, train loss: 1.4355541419009774
epoch: 3, train loss: 1.3253967445723864
epoch: 3, eval loss: 1.309086227416992, correct: 5283, total: 10000, acc = 0.5282999873161316
epoch: 4, train loss: 1.2578775298838714
epoch: 5, train loss: 1.2231916554120121
epoch: 5, eval loss: 1.1699816286563873, correct: 5695, total: 10000, acc = 0.5694999694824219
epoch: 6, train loss: 1.1872552669778162
epoch: 7, train loss: 1.1616783823285783
epoch: 7, eval loss: 1.069484794139862, correct: 6183, total: 10000, acc = 0.6182999610900879
epoch: 8, train loss: 1.1155579333402672
epoch: 9, train loss: 1.0878059365311448
epoch: 9, eval loss: 1.0522838592529298, correct: 6202, total: 10000, acc = 0.620199978351593
epoch: 10, train loss: 1.0780728623575093
epoch: 11, train loss: 1.0522098152004942
epoch: 11, eval loss: 1.0902862310409547, correct: 6148, total: 10000, acc = 0.614799976348877
epoch: 12, train loss: 1.0366473337825464
epoch: 13, train loss: 1.0067467458394108
epoch: 13, eval loss: 0.9696728616952897, correct: 6531, total: 10000, acc = 0.6530999541282654
epoch: 14, train loss: 0.9676224273078295
epoch: 15, train loss: 0.9494374029490412
epoch: 15, eval loss: 0.9511896312236786, correct: 6646, total: 10000, acc = 0.6645999550819397
epoch: 16, train loss: 0.9231320935852674
epoch: 17, train loss: 0.9023846679804276
epoch: 17, eval loss: 0.8728409796953202, correct: 6866, total: 10000, acc = 0.6865999698638916
epoch: 18, train loss: 0.8684309854799387
epoch: 19, train loss: 0.836099565637355
epoch: 19, eval loss: 0.8208363801240921, correct: 7091, total: 10000, acc = 0.7091000080108643
epoch: 20, train loss: 0.8285067890371595
epoch: 21, train loss: 0.7930980793067387
epoch: 21, eval loss: 0.7793890535831451, correct: 7235, total: 10000, acc = 0.7234999537467957
epoch: 22, train loss: 0.762698369366782
epoch: 23, train loss: 0.7376812471418964
epoch: 23, eval loss: 0.746866625547409, correct: 7340, total: 10000, acc = 0.7339999675750732
epoch: 24, train loss: 0.7071484223920472
epoch: 25, train loss: 0.6905171658311572
epoch: 25, eval loss: 0.6909466415643692, correct: 7526, total: 10000, acc = 0.7525999546051025
epoch: 26, train loss: 0.6608500091397033
epoch: 27, train loss: 0.65504517907999
epoch: 27, eval loss: 0.6612646311521531, correct: 7697, total: 10000, acc = 0.7696999907493591
epoch: 28, train loss: 0.6234641969203949
epoch: 29, train loss: 0.6107665622720913
epoch: 29, eval loss: 0.666494044661522, correct: 7704, total: 10000, acc = 0.7703999876976013
epoch: 30, train loss: 0.5875011883219894
epoch: 31, train loss: 0.5739485697478665
epoch: 31, eval loss: 0.6217960953712464, correct: 7828, total: 10000, acc = 0.7827999591827393
epoch: 32, train loss: 0.548510205684876
epoch: 33, train loss: 0.5237194764979032
epoch: 33, eval loss: 0.6254391580820083, correct: 7842, total: 10000, acc = 0.7841999530792236
epoch: 34, train loss: 0.5154265892140719
epoch: 35, train loss: 0.494700480176478
epoch: 35, eval loss: 0.5981663644313813, correct: 7963, total: 10000, acc = 0.7962999939918518
epoch: 36, train loss: 0.4785171020395902
epoch: 37, train loss: 0.46277919259606576
epoch: 37, eval loss: 0.6061880439519882, correct: 7958, total: 10000, acc = 0.795799970626831
epoch: 38, train loss: 0.4398626606075131
epoch: 39, train loss: 0.4206806777083144
epoch: 39, eval loss: 0.6158866941928863, correct: 7959, total: 10000, acc = 0.7958999872207642
epoch: 40, train loss: 0.40768756550185536
epoch: 41, train loss: 0.39494050035671313
epoch: 41, eval loss: 0.5725498422980309, correct: 8132, total: 10000, acc = 0.8131999969482422
epoch: 42, train loss: 0.3742571521778496
epoch: 43, train loss: 0.3583034301290707
epoch: 43, eval loss: 0.5765605017542839, correct: 8155, total: 10000, acc = 0.8154999613761902
epoch: 44, train loss: 0.3342630756752832
epoch: 45, train loss: 0.31316718063792404
epoch: 45, eval loss: 0.583588008582592, correct: 8199, total: 10000, acc = 0.8198999762535095
epoch: 46, train loss: 0.30922748148441315
epoch: 47, train loss: 0.2906164434187266
epoch: 47, eval loss: 0.5934860140085221, correct: 8143, total: 10000, acc = 0.814300000667572
epoch: 48, train loss: 0.2741488078419043
epoch: 49, train loss: 0.2597196321098172
epoch: 49, eval loss: 0.5978868633508683, correct: 8195, total: 10000, acc = 0.8194999694824219
epoch: 50, train loss: 0.2440016470393356
epoch: 51, train loss: 0.2293997729311184
epoch: 51, eval loss: 0.5915440261363983, correct: 8232, total: 10000, acc = 0.823199987411499
epoch: 52, train loss: 0.2132072006257213
epoch: 53, train loss: 0.19785404767917128
epoch: 53, eval loss: 0.6171442106366157, correct: 8258, total: 10000, acc = 0.8258000016212463
epoch: 54, train loss: 0.1838149410121295
epoch: 55, train loss: 0.17691133977199086
epoch: 55, eval loss: 0.623777586221695, correct: 8275, total: 10000, acc = 0.8274999856948853
epoch: 56, train loss: 0.16595362697024735
epoch: 57, train loss: 0.1531825682946614
epoch: 57, eval loss: 0.6466041743755341, correct: 8243, total: 10000, acc = 0.8242999911308289
epoch: 58, train loss: 0.14334788979316243
epoch: 59, train loss: 0.13799503377201605
epoch: 59, eval loss: 0.6496601745486259, correct: 8249, total: 10000, acc = 0.8248999714851379
finish training
c196-011[rtx](1013)$ bash ./test.sh 1 1 1 0.001
TACC: Starting up job 3503164
TACC: Starting parallel tasks...
warning: variables which starts with __, is a module or class declaration are omitted
process rank 0 is bound to device 0
distributed environment is initialized
USE_VANILLA model
model is created
Files already downloaded and verified
Files already downloaded and verified
training and testing dataloaders are created
loss is created
optimizer is created
start training
epoch: 0, train loss: 1.9408839624755236
epoch: 0, eval loss: 1.7896566271781922, correct: 3488, total: 10000, acc = 0.34880000352859497
epoch time: 40.82966494560242
epoch: 1, train loss: 1.6500030257263962
epoch: 1, eval loss: 1.5464953780174255, correct: 4545, total: 10000, acc = 0.4544999897480011
epoch time: 40.01254224777222
epoch: 2, train loss: 1.422887429899099
epoch: 2, eval loss: 1.37536381483078, correct: 5074, total: 10000, acc = 0.5073999762535095
epoch time: 40.107905864715576
epoch: 3, train loss: 1.3217590207956276
epoch: 3, eval loss: 1.3036327004432677, correct: 5377, total: 10000, acc = 0.5376999974250793
epoch time: 40.12306189537048
epoch: 4, train loss: 1.262234352072891
epoch: 4, eval loss: 1.2568134129047395, correct: 5475, total: 10000, acc = 0.5475000143051147
epoch time: 40.10755228996277
epoch: 5, train loss: 1.2381379117771072
epoch: 5, eval loss: 1.1941023647785187, correct: 5676, total: 10000, acc = 0.5676000118255615
epoch time: 40.119303464889526
epoch: 6, train loss: 1.2061052650821453
epoch: 6, eval loss: 1.1313925206661224, correct: 5938, total: 10000, acc = 0.5938000082969666
epoch time: 40.07719683647156
epoch: 7, train loss: 1.1659562563409611
epoch: 7, eval loss: 1.125486546754837, correct: 5958, total: 10000, acc = 0.59579998254776
epoch time: 40.1702299118042
epoch: 8, train loss: 1.1378972846634534
epoch: 8, eval loss: 1.082760637998581, correct: 6102, total: 10000, acc = 0.6101999878883362
epoch time: 40.22099733352661
epoch: 9, train loss: 1.1073276430976635
epoch: 9, eval loss: 1.1077564001083373, correct: 6038, total: 10000, acc = 0.6037999987602234
epoch time: 40.1106858253479
epoch: 10, train loss: 1.087894769347444
epoch: 10, eval loss: 1.0400531351566316, correct: 6311, total: 10000, acc = 0.6310999989509583
epoch time: 40.20973324775696
epoch: 11, train loss: 1.0556547295074075
epoch: 11, eval loss: 1.0295817345380782, correct: 6359, total: 10000, acc = 0.6358999609947205
epoch time: 40.23791980743408
epoch: 12, train loss: 1.0299884901971232
epoch: 12, eval loss: 1.003737959265709, correct: 6380, total: 10000, acc = 0.6380000114440918
epoch time: 40.08779859542847
epoch: 13, train loss: 0.9972386627781148
epoch: 13, eval loss: 0.9707699298858643, correct: 6499, total: 10000, acc = 0.649899959564209
epoch time: 40.10878801345825
epoch: 14, train loss: 0.9784559072280417
epoch: 14, eval loss: 0.9253897607326508, correct: 6641, total: 10000, acc = 0.6640999913215637
epoch time: 40.13168978691101
epoch: 15, train loss: 0.9409253481699495
epoch: 15, eval loss: 0.9120320588350296, correct: 6759, total: 10000, acc = 0.6758999824523926
epoch time: 40.162830114364624
epoch: 16, train loss: 0.925923115136672
epoch: 16, eval loss: 0.8850776582956315, correct: 6870, total: 10000, acc = 0.6869999766349792
epoch time: 40.145774602890015
epoch: 17, train loss: 0.8923340841215484
epoch: 17, eval loss: 0.8570599347352982, correct: 6950, total: 10000, acc = 0.6949999928474426
epoch time: 40.18058943748474
epoch: 18, train loss: 0.8638542884466599
epoch: 18, eval loss: 0.838410159945488, correct: 6971, total: 10000, acc = 0.6970999836921692
epoch time: 40.110822439193726
epoch: 19, train loss: 0.8400422529298432
epoch: 19, eval loss: 0.8189669162034988, correct: 7097, total: 10000, acc = 0.7096999883651733
epoch time: 40.066970109939575
epoch: 20, train loss: 0.8072922752828015
epoch: 20, eval loss: 0.7772788077592849, correct: 7240, total: 10000, acc = 0.7239999771118164
epoch time: 40.045086145401
epoch: 21, train loss: 0.788195074821005
epoch: 21, eval loss: 0.7793144911527634, correct: 7261, total: 10000, acc = 0.726099967956543
epoch time: 40.05983781814575
epoch: 22, train loss: 0.7574447350842612
epoch: 22, eval loss: 0.7660320281982422, correct: 7272, total: 10000, acc = 0.7271999716758728
epoch time: 40.11693739891052
epoch: 23, train loss: 0.7402738150285215
epoch: 23, eval loss: 0.7264292597770691, correct: 7418, total: 10000, acc = 0.7418000102043152
epoch time: 40.18724513053894
epoch: 24, train loss: 0.7125097580102026
epoch: 24, eval loss: 0.7105035990476608, correct: 7506, total: 10000, acc = 0.7505999803543091
epoch time: 40.1254940032959
epoch: 25, train loss: 0.6900304744438249
epoch: 25, eval loss: 0.6911167114973068, correct: 7562, total: 10000, acc = 0.7561999559402466
epoch time: 40.103896617889404
epoch: 26, train loss: 0.6648721482072558
epoch: 26, eval loss: 0.6780407190322876, correct: 7624, total: 10000, acc = 0.7623999714851379
epoch time: 40.18161463737488
epoch: 27, train loss: 0.6446310062797702
epoch: 27, eval loss: 0.6820667266845704, correct: 7612, total: 10000, acc = 0.761199951171875
epoch time: 40.19018864631653
epoch: 28, train loss: 0.6262476389505425
epoch: 28, eval loss: 0.6506347745656967, correct: 7704, total: 10000, acc = 0.7703999876976013
epoch time: 40.23526978492737
epoch: 29, train loss: 0.5968854001590184
epoch: 29, eval loss: 0.6507940381765366, correct: 7727, total: 10000, acc = 0.7726999521255493
epoch time: 40.26889181137085
epoch: 30, train loss: 0.587430303194085
epoch: 30, eval loss: 0.6333519726991653, correct: 7788, total: 10000, acc = 0.7787999510765076
epoch time: 40.28285789489746
epoch: 31, train loss: 0.5701514035463333
epoch: 31, eval loss: 0.6348810195922852, correct: 7799, total: 10000, acc = 0.7798999547958374
epoch time: 40.199995040893555
epoch: 32, train loss: 0.5482188679125845
epoch: 32, eval loss: 0.6192457497119903, correct: 7833, total: 10000, acc = 0.78329998254776
epoch time: 40.270729780197144
epoch: 33, train loss: 0.534268391375639
epoch: 33, eval loss: 0.6381673783063888, correct: 7790, total: 10000, acc = 0.7789999842643738
epoch time: 40.36342120170593
epoch: 34, train loss: 0.5104483384258893
epoch: 34, eval loss: 0.6173199415206909, correct: 7867, total: 10000, acc = 0.7866999506950378
epoch time: 40.34266257286072
epoch: 35, train loss: 0.4968841674984718
epoch: 35, eval loss: 0.604002220928669, correct: 7916, total: 10000, acc = 0.7915999889373779
epoch time: 40.39444589614868
epoch: 36, train loss: 0.4773432207959039
epoch: 36, eval loss: 0.5884111285209656, correct: 7965, total: 10000, acc = 0.7964999675750732
epoch time: 40.40647268295288
epoch: 37, train loss: 0.4621481445370888
epoch: 37, eval loss: 0.5748852327466011, correct: 8047, total: 10000, acc = 0.8046999573707581
epoch time: 40.29281520843506
epoch: 38, train loss: 0.4431859048045411
epoch: 38, eval loss: 0.5874941781163215, correct: 7995, total: 10000, acc = 0.7994999885559082
epoch time: 40.40029954910278
epoch: 39, train loss: 0.4305852785402415
epoch: 39, eval loss: 0.5991648495197296, correct: 7972, total: 10000, acc = 0.7971999645233154
epoch time: 40.399904012680054
epoch: 40, train loss: 0.4092241589512144
epoch: 40, eval loss: 0.5725525215268135, correct: 8069, total: 10000, acc = 0.8068999648094177
epoch time: 40.32663059234619
epoch: 41, train loss: 0.39218547179990887
epoch: 41, eval loss: 0.5886161357164383, correct: 8068, total: 10000, acc = 0.8068000078201294
epoch time: 40.32424521446228
epoch: 42, train loss: 0.3773612398274091
epoch: 42, eval loss: 0.5762413635849952, correct: 8126, total: 10000, acc = 0.8125999569892883
epoch time: 40.44430422782898
epoch: 43, train loss: 0.3593267098981507
epoch: 43, eval loss: 0.5729024946689606, correct: 8107, total: 10000, acc = 0.810699999332428
epoch time: 40.488121032714844
epoch: 44, train loss: 0.3396431426612698
epoch: 44, eval loss: 0.5944831907749176, correct: 8072, total: 10000, acc = 0.8071999549865723
epoch time: 40.41803979873657
epoch: 45, train loss: 0.32412939716358574
epoch: 45, eval loss: 0.5849291861057282, correct: 8171, total: 10000, acc = 0.8170999884605408
epoch time: 40.428131341934204
epoch: 46, train loss: 0.3099915471916296
epoch: 46, eval loss: 0.5797522723674774, correct: 8121, total: 10000, acc = 0.8120999932289124
epoch time: 40.623990058898926
epoch: 47, train loss: 0.29422828676749246
epoch: 47, eval loss: 0.5898703813552857, correct: 8175, total: 10000, acc = 0.8174999952316284
epoch time: 40.71224045753479
epoch: 48, train loss: 0.27581544600579205
epoch: 48, eval loss: 0.5950756087899208, correct: 8170, total: 10000, acc = 0.8169999718666077
epoch time: 40.53409385681152
epoch: 49, train loss: 0.26118586242807157
epoch: 49, eval loss: 0.5998703584074974, correct: 8213, total: 10000, acc = 0.8212999701499939
epoch time: 40.564385175704956
epoch: 50, train loss: 0.2513351797753451
epoch: 50, eval loss: 0.6011391341686249, correct: 8226, total: 10000, acc = 0.8226000070571899
epoch time: 40.55033254623413
epoch: 51, train loss: 0.22965944299892505
epoch: 51, eval loss: 0.5979882061481476, correct: 8233, total: 10000, acc = 0.8233000040054321
epoch time: 40.54532980918884
epoch: 52, train loss: 0.21661002188920975
epoch: 52, eval loss: 0.6121026620268821, correct: 8220, total: 10000, acc = 0.8219999670982361
epoch time: 40.649473667144775
epoch: 53, train loss: 0.20266114950788264
epoch: 53, eval loss: 0.6016955643892288, correct: 8260, total: 10000, acc = 0.8259999752044678
epoch time: 40.752054929733276
epoch: 54, train loss: 0.19287180794136866
epoch: 54, eval loss: 0.6043265879154205, correct: 8284, total: 10000, acc = 0.8283999562263489
epoch time: 40.68043255805969
epoch: 55, train loss: 0.175087109208107
epoch: 55, eval loss: 0.6146622076630592, correct: 8316, total: 10000, acc = 0.8315999507904053
epoch time: 40.58446717262268
epoch: 56, train loss: 0.16749868762432313
epoch: 56, eval loss: 0.6235148012638092, correct: 8313, total: 10000, acc = 0.8312999606132507
epoch time: 40.62826180458069
epoch: 57, train loss: 0.15567801619062618
epoch: 57, eval loss: 0.6325852945446968, correct: 8308, total: 10000, acc = 0.8307999968528748
epoch time: 40.72224497795105
epoch: 58, train loss: 0.1484297229623308
epoch: 58, eval loss: 0.6329193383455276, correct: 8325, total: 10000, acc = 0.8324999809265137
epoch time: 40.750558614730835
epoch: 59, train loss: 0.14238623818572688
epoch: 59, eval loss: 0.6318104699254036, correct: 8329, total: 10000, acc = 0.8328999876976013
epoch time: 40.77172636985779
finish training
TACC: Starting up job 3498663
TACC: Starting parallel tasks...
warning: variables which starts with __, is a module or class declaration are omitted
process rank 0 is bound to device 0
distributed environment is initialized
model is created
Files already downloaded and verified
Files already downloaded and verified
training and testing dataloaders are created
loss is created
optimizer is created
start training
epoch: 0, train loss: 2.095031557034473
epoch: 1, train loss: 1.8454539605549403
epoch: 1, eval loss: 1.7768513083457946, correct: 3564, total: 10000, acc = 0.3563999831676483
epoch: 2, train loss: 1.7044833728245325
epoch: 3, train loss: 1.5999061124665397
epoch: 3, eval loss: 1.5574450254440309, correct: 4389, total: 10000, acc = 0.4388999938964844
epoch: 4, train loss: 1.4929670217085858
epoch: 5, train loss: 1.401450170546162
epoch: 5, eval loss: 1.4644017696380616, correct: 4857, total: 10000, acc = 0.48569998145103455
epoch: 6, train loss: 1.319102376091237
epoch: 7, train loss: 1.2555806539496597
epoch: 7, eval loss: 1.2475590467453004, correct: 5486, total: 10000, acc = 0.5485999584197998
epoch: 8, train loss: 1.1992503173497258
epoch: 9, train loss: 1.1600336493278036
epoch: 9, eval loss: 1.1786625683307648, correct: 5834, total: 10000, acc = 0.5834000110626221
epoch: 10, train loss: 1.1214540807568296
epoch: 11, train loss: 1.0808329728184913
epoch: 11, eval loss: 1.096825110912323, correct: 6072, total: 10000, acc = 0.6071999669075012
epoch: 12, train loss: 1.0521019423494533
epoch: 13, train loss: 1.0262362957000732
epoch: 13, eval loss: 1.056444275379181, correct: 6268, total: 10000, acc = 0.626800000667572
epoch: 14, train loss: 0.9932536555796253
epoch: 15, train loss: 0.9653559442685575
epoch: 15, eval loss: 0.9576991081237793, correct: 6582, total: 10000, acc = 0.6581999659538269
epoch: 16, train loss: 0.9465620943478176
epoch: 17, train loss: 0.9181081974992946
epoch: 17, eval loss: 0.9245584070682525, correct: 6747, total: 10000, acc = 0.6746999621391296
epoch: 18, train loss: 0.8987109752333894
epoch: 19, train loss: 0.8840238646585115
epoch: 19, eval loss: 0.8989996433258056, correct: 6787, total: 10000, acc = 0.6786999702453613
epoch: 20, train loss: 0.8591911811001447
epoch: 21, train loss: 0.843510093129411
epoch: 21, eval loss: 0.8595858901739121, correct: 6969, total: 10000, acc = 0.6969000101089478
epoch: 22, train loss: 0.8306782276046519
epoch: 23, train loss: 0.8181647640101763
epoch: 23, eval loss: 0.8600298583507537, correct: 7005, total: 10000, acc = 0.7005000114440918
epoch: 24, train loss: 0.7964763343334198
epoch: 25, train loss: 0.7840689718723297
epoch: 25, eval loss: 0.824479615688324, correct: 7073, total: 10000, acc = 0.7073000073432922
epoch: 26, train loss: 0.7709570752114666
epoch: 27, train loss: 0.7591698108887186
epoch: 27, eval loss: 0.7967212647199631, correct: 7196, total: 10000, acc = 0.7195999622344971
epoch: 28, train loss: 0.7438001352913526
epoch: 29, train loss: 0.7341659853653032
epoch: 29, eval loss: 0.8041222035884857, correct: 7168, total: 10000, acc = 0.7167999744415283
epoch: 30, train loss: 0.7254330929444761
epoch: 31, train loss: 0.710246913895315
epoch: 31, eval loss: 0.7848481118679047, correct: 7287, total: 10000, acc = 0.7286999821662903
epoch: 32, train loss: 0.6976562008565786
epoch: 33, train loss: 0.6906438475968887
epoch: 33, eval loss: 0.7644171923398971, correct: 7370, total: 10000, acc = 0.7369999885559082
epoch: 34, train loss: 0.6795850834067987
epoch: 35, train loss: 0.6724951656497254
epoch: 35, eval loss: 0.7515032321214676, correct: 7368, total: 10000, acc = 0.736799955368042
epoch: 36, train loss: 0.6527298372619006
epoch: 37, train loss: 0.651018523440069
epoch: 37, eval loss: 0.7381327033042908, correct: 7449, total: 10000, acc = 0.7448999881744385
epoch: 38, train loss: 0.6365304406808348
epoch: 39, train loss: 0.6372388047831399
epoch: 39, eval loss: 0.7342826008796692, correct: 7453, total: 10000, acc = 0.7452999949455261
epoch: 40, train loss: 0.6199644664112403
epoch: 41, train loss: 0.6101092303894005
epoch: 41, eval loss: 0.7353240340948105, correct: 7466, total: 10000, acc = 0.7465999722480774
epoch: 42, train loss: 0.6093496211937496
epoch: 43, train loss: 0.6019633388032719
epoch: 43, eval loss: 0.7350291252136231, correct: 7479, total: 10000, acc = 0.7479000091552734
epoch: 44, train loss: 0.5928211437196148
epoch: 45, train loss: 0.5840530048827736
epoch: 45, eval loss: 0.7301350146532058, correct: 7525, total: 10000, acc = 0.7524999976158142
epoch: 46, train loss: 0.578370426078232
epoch: 47, train loss: 0.5703256440405943
epoch: 47, eval loss: 0.7226948082447052, correct: 7526, total: 10000, acc = 0.7525999546051025
epoch: 48, train loss: 0.5622531275968162
epoch: 49, train loss: 0.5543749076979501
epoch: 49, eval loss: 0.7278151929378509, correct: 7536, total: 10000, acc = 0.753600001335144
epoch: 50, train loss: 0.5494355583677486
epoch: 51, train loss: 0.5427058047177841
epoch: 51, eval loss: 0.7180711388587951, correct: 7608, total: 10000, acc = 0.7608000040054321
epoch: 52, train loss: 0.5323820530760045
epoch: 53, train loss: 0.5341374232452742
epoch: 53, eval loss: 0.7136827558279037, correct: 7618, total: 10000, acc = 0.7617999911308289
epoch: 54, train loss: 0.5295403867351766
epoch: 55, train loss: 0.5226148692320804
epoch: 55, eval loss: 0.7158426463603973, correct: 7624, total: 10000, acc = 0.7623999714851379
epoch: 56, train loss: 0.5206544593888887
epoch: 57, train loss: 0.5186455438331682
epoch: 57, eval loss: 0.7141193479299546, correct: 7611, total: 10000, acc = 0.7610999941825867
epoch: 58, train loss: 0.5130856335163116
epoch: 59, train loss: 0.5103850683995655
epoch: 59, eval loss: 0.7077989399433136, correct: 7628, total: 10000, acc = 0.7627999782562256
finish training
c196-012[rtx](1006)$ bash ./test.sh 1 1 1 0.0001
TACC: Starting up job 3503177
TACC: Starting parallel tasks...
warning: variables which start with __, or are a module or class declaration, are omitted
process rank 0 is bound to device 0
distributed environment is initialized
USE_VANILLA model
model is created
Files already downloaded and verified
Files already downloaded and verified
training and testing dataloaders are created
loss is created
optimizer is created
start training
epoch: 0, train loss: 2.07912605757616
epoch: 0, eval loss: 1.9337591707706452, correct: 2845, total: 10000, acc = 0.28450000286102295
epoch time: 48.79993748664856
epoch: 1, train loss: 1.8506990890113675
epoch: 1, eval loss: 1.7832269430160523, correct: 3506, total: 10000, acc = 0.350600004196167
epoch time: 39.10968255996704
epoch: 2, train loss: 1.707400695401795
epoch: 2, eval loss: 1.6983122050762176, correct: 3935, total: 10000, acc = 0.3935000002384186
epoch time: 39.205119609832764
epoch: 3, train loss: 1.5925798574272467
epoch: 3, eval loss: 1.6361137092113496, correct: 4276, total: 10000, acc = 0.4275999963283539
epoch time: 39.220152378082275
epoch: 4, train loss: 1.4817699790000916
epoch: 4, eval loss: 1.4869949519634247, correct: 4706, total: 10000, acc = 0.4705999791622162
epoch time: 39.297648191452026
epoch: 5, train loss: 1.3685331247290786
epoch: 5, eval loss: 1.4110832333564758, correct: 5043, total: 10000, acc = 0.5042999982833862
epoch time: 39.31484127044678
epoch: 6, train loss: 1.283743022655954
epoch: 6, eval loss: 1.317776972055435, correct: 5320, total: 10000, acc = 0.5320000052452087
epoch time: 39.31891870498657
epoch: 7, train loss: 1.2292176107971036
epoch: 7, eval loss: 1.2397323846817017, correct: 5619, total: 10000, acc = 0.5618999600410461
epoch time: 39.31014013290405
epoch: 8, train loss: 1.1705418606193698
epoch: 8, eval loss: 1.2041720151901245, correct: 5696, total: 10000, acc = 0.569599986076355
epoch time: 39.29190945625305
epoch: 9, train loss: 1.1253369718181843
epoch: 9, eval loss: 1.1219275832176208, correct: 6039, total: 10000, acc = 0.6038999557495117
epoch time: 39.314892053604126
epoch: 10, train loss: 1.0875617825255102
epoch: 10, eval loss: 1.1398449420928956, correct: 5921, total: 10000, acc = 0.5920999646186829
epoch time: 39.29768466949463
epoch: 11, train loss: 1.055325626110544
epoch: 11, eval loss: 1.0739773243665696, correct: 6212, total: 10000, acc = 0.6211999654769897
epoch time: 39.26834416389465
epoch: 12, train loss: 1.0238730627663282
epoch: 12, eval loss: 1.0526267528533935, correct: 6244, total: 10000, acc = 0.6243999600410461
epoch time: 39.30522894859314
epoch: 13, train loss: 0.9906492087305808
epoch: 13, eval loss: 1.0342225402593612, correct: 6295, total: 10000, acc = 0.6294999718666077
epoch time: 39.28985071182251
epoch: 14, train loss: 0.968360669758855
epoch: 14, eval loss: 0.9747557610273361, correct: 6498, total: 10000, acc = 0.6498000025749207
epoch time: 39.33563685417175
epoch: 15, train loss: 0.9413909072778663
epoch: 15, eval loss: 0.9359912216663361, correct: 6659, total: 10000, acc = 0.6658999919891357
epoch time: 39.332377672195435
epoch: 16, train loss: 0.9215109226654987
epoch: 16, eval loss: 0.9215879321098328, correct: 6693, total: 10000, acc = 0.6692999601364136
epoch time: 39.35148882865906
epoch: 17, train loss: 0.9036085179873875
epoch: 17, eval loss: 0.8947311192750931, correct: 6787, total: 10000, acc = 0.6786999702453613
epoch time: 39.31995511054993
epoch: 18, train loss: 0.8774841433885147
epoch: 18, eval loss: 0.8880111247301101, correct: 6844, total: 10000, acc = 0.6843999624252319
epoch time: 39.32100558280945
epoch: 19, train loss: 0.8607137598553483
epoch: 19, eval loss: 0.8770220369100571, correct: 6883, total: 10000, acc = 0.6882999539375305
epoch time: 39.3321533203125
epoch: 20, train loss: 0.8482279163234088
epoch: 20, eval loss: 0.8661656975746155, correct: 6926, total: 10000, acc = 0.6926000118255615
epoch time: 39.319167613983154
epoch: 21, train loss: 0.8280732814146547
epoch: 21, eval loss: 0.8369802534580231, correct: 7041, total: 10000, acc = 0.7040999531745911
epoch time: 39.32543706893921
epoch: 22, train loss: 0.8162973212952517
epoch: 22, eval loss: 0.8281545102596283, correct: 7096, total: 10000, acc = 0.7095999717712402
epoch time: 39.344929695129395
epoch: 23, train loss: 0.8043988426120914
epoch: 23, eval loss: 0.8369941651821137, correct: 7070, total: 10000, acc = 0.7069999575614929
epoch time: 39.342397928237915
epoch: 24, train loss: 0.788704516328111
epoch: 24, eval loss: 0.8305304765701294, correct: 7040, total: 10000, acc = 0.7039999961853027
epoch time: 39.349589347839355
epoch: 25, train loss: 0.7747861517935383
epoch: 25, eval loss: 0.8025588423013688, correct: 7164, total: 10000, acc = 0.7163999676704407
epoch time: 39.35692596435547
epoch: 26, train loss: 0.7557641073149077
epoch: 26, eval loss: 0.7929455429315567, correct: 7204, total: 10000, acc = 0.7203999757766724
epoch time: 39.36091661453247
epoch: 27, train loss: 0.7422851062550837
epoch: 27, eval loss: 0.7790816932916641, correct: 7249, total: 10000, acc = 0.7249000072479248
epoch time: 39.355828046798706
epoch: 28, train loss: 0.7305653861590794
epoch: 28, eval loss: 0.7937072366476059, correct: 7204, total: 10000, acc = 0.7203999757766724
epoch time: 39.3598473072052
epoch: 29, train loss: 0.719313730998915
epoch: 29, eval loss: 0.7657937437295914, correct: 7320, total: 10000, acc = 0.7319999933242798
epoch time: 39.353551626205444
epoch: 30, train loss: 0.7127084263733455
epoch: 30, eval loss: 0.7556168884038925, correct: 7341, total: 10000, acc = 0.7340999841690063
epoch time: 39.37097501754761
epoch: 31, train loss: 0.7044506967067719
epoch: 31, eval loss: 0.7438590109348298, correct: 7359, total: 10000, acc = 0.7358999848365784
epoch time: 39.37364745140076
epoch: 32, train loss: 0.6920064693810989
epoch: 32, eval loss: 0.7408553540706635, correct: 7419, total: 10000, acc = 0.7418999671936035
epoch time: 39.372353076934814
epoch: 33, train loss: 0.6790882920732304
epoch: 33, eval loss: 0.7541307628154754, correct: 7332, total: 10000, acc = 0.733199954032898
epoch time: 39.310251235961914
epoch: 34, train loss: 0.6666433202977083
epoch: 34, eval loss: 0.7413494348526001, correct: 7401, total: 10000, acc = 0.7400999665260315
epoch time: 39.394805908203125
epoch: 35, train loss: 0.6561720742254841
epoch: 35, eval loss: 0.7245241671800613, correct: 7483, total: 10000, acc = 0.7482999563217163
epoch time: 39.34455704689026
epoch: 36, train loss: 0.6433814526820669
epoch: 36, eval loss: 0.7294039458036423, correct: 7483, total: 10000, acc = 0.7482999563217163
epoch time: 39.337549924850464
epoch: 37, train loss: 0.6366085136423305
epoch: 37, eval loss: 0.7336494833230972, correct: 7462, total: 10000, acc = 0.7461999654769897
epoch time: 39.338196754455566
epoch: 38, train loss: 0.6294400272320728
epoch: 38, eval loss: 0.719609409570694, correct: 7532, total: 10000, acc = 0.7531999945640564
epoch time: 39.33430027961731
epoch: 39, train loss: 0.6179663903859197
epoch: 39, eval loss: 0.7210630685091018, correct: 7507, total: 10000, acc = 0.7506999969482422
epoch time: 39.33643341064453
epoch: 40, train loss: 0.6102935781284254
epoch: 40, eval loss: 0.6994094282388688, correct: 7569, total: 10000, acc = 0.7568999528884888
epoch time: 39.38672637939453
epoch: 41, train loss: 0.5990810029360712
epoch: 41, eval loss: 0.7133035778999328, correct: 7550, total: 10000, acc = 0.7549999952316284
epoch time: 39.374757528305054
epoch: 42, train loss: 0.5964441865074391
epoch: 42, eval loss: 0.7060712993144989, correct: 7577, total: 10000, acc = 0.7576999664306641
epoch time: 39.4019033908844
epoch: 43, train loss: 0.5878602710305428
epoch: 43, eval loss: 0.7106044471263886, correct: 7580, total: 10000, acc = 0.7579999566078186
epoch time: 39.408252477645874
epoch: 44, train loss: 0.5797601254010687
epoch: 44, eval loss: 0.7093768745660782, correct: 7568, total: 10000, acc = 0.7567999958992004
epoch time: 39.40289378166199
epoch: 45, train loss: 0.5684604742089097
epoch: 45, eval loss: 0.7075642883777619, correct: 7612, total: 10000, acc = 0.761199951171875
epoch time: 39.35792422294617
epoch: 46, train loss: 0.5617077308041709
epoch: 46, eval loss: 0.707081851363182, correct: 7576, total: 10000, acc = 0.7576000094413757
epoch time: 39.37784481048584
epoch: 47, train loss: 0.5572127462649832
epoch: 47, eval loss: 0.7069586098194123, correct: 7606, total: 10000, acc = 0.7605999708175659
epoch time: 39.33794188499451
epoch: 48, train loss: 0.5519619742218329
epoch: 48, eval loss: 0.6923990368843078, correct: 7679, total: 10000, acc = 0.7678999900817871
epoch time: 39.39500594139099
epoch: 49, train loss: 0.5454421751961416
epoch: 49, eval loss: 0.7032370567321777, correct: 7626, total: 10000, acc = 0.7626000046730042
epoch time: 39.38570594787598
epoch: 50, train loss: 0.5419908360559114
epoch: 50, eval loss: 0.6949253618717194, correct: 7669, total: 10000, acc = 0.7669000029563904
epoch time: 39.334325551986694
epoch: 51, train loss: 0.5299993215166793
epoch: 51, eval loss: 0.6966427147388459, correct: 7654, total: 10000, acc = 0.7653999924659729
epoch time: 39.337984561920166
epoch: 52, train loss: 0.5282451452649369
epoch: 52, eval loss: 0.6932955116033555, correct: 7664, total: 10000, acc = 0.7663999795913696
epoch time: 39.34237813949585
epoch: 53, train loss: 0.5234840703862054
epoch: 53, eval loss: 0.6988086104393005, correct: 7654, total: 10000, acc = 0.7653999924659729
epoch time: 39.364726066589355
epoch: 54, train loss: 0.5139317989957576
epoch: 54, eval loss: 0.6950253814458847, correct: 7643, total: 10000, acc = 0.7642999887466431
epoch time: 39.40451097488403
epoch: 55, train loss: 0.5158528734226616
epoch: 55, eval loss: 0.6978882610797882, correct: 7672, total: 10000, acc = 0.7671999931335449
epoch time: 39.38926696777344
epoch: 56, train loss: 0.5082419429506574
epoch: 56, eval loss: 0.6909049898386002, correct: 7692, total: 10000, acc = 0.7691999673843384
epoch time: 39.42493271827698
epoch: 57, train loss: 0.5027476120360044
epoch: 57, eval loss: 0.6897687911987305, correct: 7695, total: 10000, acc = 0.7694999575614929
epoch time: 39.35954570770264
epoch: 58, train loss: 0.5053188776483342
epoch: 58, eval loss: 0.6899506479501725, correct: 7667, total: 10000, acc = 0.7666999697685242
epoch time: 39.44884634017944
epoch: 59, train loss: 0.4997740634241883
epoch: 59, eval loss: 0.687486720085144, correct: 7678, total: 10000, acc = 0.767799973487854
epoch time: 39.391881465911865
finish training
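The eval lines above follow a fixed format. A minimal sketch of how per-epoch accuracy could be pulled out of such a log for plotting or comparison (the helper name and regex are illustrative, assuming exactly the `epoch: N, eval loss: ..., correct: ..., total: ..., acc = ...` line shape seen above):

```python
import re

# Matches eval lines such as:
# "epoch: 59, eval loss: 0.687486720085144, correct: 7678, total: 10000, acc = 0.767799973487854"
EVAL_RE = re.compile(
    r"epoch: (\d+), eval loss: ([\d.]+), correct: (\d+), total: (\d+), acc = ([\d.]+)"
)

def parse_eval_lines(log_text):
    """Return (epoch, eval_loss, accuracy) tuples for every eval line in a log."""
    results = []
    for line in log_text.splitlines():
        m = EVAL_RE.match(line.strip())
        if m:
            epoch, loss, _correct, _total, acc = m.groups()
            results.append((int(epoch), float(loss), float(acc)))
    return results

sample = "epoch: 59, eval loss: 0.687486720085144, correct: 7678, total: 10000, acc = 0.767799973487854"
print(parse_eval_lines(sample))  # [(59, 0.687486720085144, 0.767799973487854)]
```

This makes it easy to compare the two runs above (e.g. final accuracy 0.7628 vs. 0.7678) without copying numbers by hand.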