# iTransformer
A vanilla Transformer looks at all features within a single time step and cannot efficiently exploit long-range temporal patterns. iTransformer instead looks at one feature across many time steps, and can forecast multiple variates simultaneously.
## Paper
`iTransformer: Inverted Transformers Are Effective for Time Series Forecasting`
- https://arxiv.org/pdf/2310.06625

## Model Architecture
iTransformer uses the standard Transformer encoder as its backbone; the algorithm can be implemented without modifying the standard Transformer code.
<div align=center>
    <img src="./doc/transformer.png"/>
</div>

## Algorithm
iTransformer attends to one feature across multiple time steps simply by transposing the input. Instead of tokenizing sub-sequences of the input, the model tokenizes each variate's entire input series. This way, the attention layers can focus on learning multivariate correlations, while the feed-forward network encodes each whole series.
<div align=center>
    <img src="./doc/iTransformer.png"/>
</div>
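
The inverted tokenization described above can be sketched in a few lines of PyTorch. This is an illustrative toy, not the repository's implementation; the lookback length (96) and variate count (7) follow the ETT setup, while the layer sizes are chosen only for the example:

```python
import torch
import torch.nn as nn

B, T, N, d_model = 8, 96, 7, 64          # batch, lookback length, variates, hidden size

x = torch.randn(B, T, N)                 # standard input layout: [batch, time, variates]
x_inv = x.transpose(1, 2)                # inverted layout:       [batch, variates, time]

# Embed each variate's whole length-T series as one token.
embed = nn.Linear(T, d_model)
tokens = embed(x_inv)                    # [B, N, d_model]

# An unmodified standard encoder now attends across variates, not time steps.
encoder_layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
out = encoder(tokens)                    # [B, N, d_model]

# Project each variate token back into its own forecast horizon.
pred_len = 96
project = nn.Linear(d_model, pred_len)
forecast = project(out).transpose(1, 2)  # [B, pred_len, N]
print(forecast.shape)                    # torch.Size([8, 96, 7])
```

Note that the attention matrix here is N x N (7 x 7 across variates) rather than T x T, which is what lets the attention layers specialize in multivariate correlations.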

## Environment Setup
```
mv itransformer_pytorch iTransformer # drop the framework suffix from the directory name
```

### Docker (Option 1)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
# replace <your IMAGE ID> with the image ID of the docker image pulled above
docker run -it --shm-size=32G -v $PWD/iTransformer:/home/iTransformer -v /opt/hyhal:/opt/hyhal --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name itransformer <your IMAGE ID> bash
cd /home/iTransformer
pip install -r requirements.txt
```
### Dockerfile (Option 2)
```
cd iTransformer/docker
docker build --no-cache -t itransformer:latest .
docker run --shm-size=32G --name itransformer -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video -v $PWD/../../iTransformer:/home/iTransformer -it itransformer bash
# If installing the environment via the Dockerfile takes a long time, comment out the pip installs in the Dockerfile and install the Python packages after the container starts: pip install -r requirements.txt
```
### Anaconda (Option 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded from the SourceFind developer community:
- https://developer.sourcefind.cn/tool/
```
DTK driver: dtk24.04.1
python: python3.10
torch:2.1.0
torchvision:0.16.0
```

`Tip: the DTK driver, Python, torch, and related DCU tool versions above must match each other exactly.`

2. Install the remaining, non-DCU-specific libraries from requirements.txt:
```
pip install -r requirements.txt
```

## Dataset
This walkthrough uses [`ETTm2`](./dataset/ETT-small/ETTm2.csv) from the ETT-small collection.

The dataset directory is laid out as follows:
```
dataset/ETT-small
    ├── ETTh1.csv
    ├── ETTh2.csv
    ├── ETTm1.csv
    └── ETTm2.csv
```
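
Training samples are cut from these series with a sliding window (the ETT scripts pair a lookback of 96 steps with a forecast horizon of 96 steps). A minimal, self-contained sketch of that windowing, with random data standing in for the CSV contents:

```python
import numpy as np

# Toy stand-in for the values loaded from ETTm2.csv (7 variates).
n_vars, seq_len, pred_len = 7, 96, 96
series = np.random.randn(1000, n_vars)

def make_window(data, start, seq_len, pred_len):
    """Cut one (input, target) pair starting at row `start`."""
    x = data[start : start + seq_len]               # model input
    y = data[start + seq_len : start + seq_len + pred_len]  # forecast target
    return x, y

x, y = make_window(series, 0, seq_len, pred_len)
print(x.shape, y.shape)  # (96, 7) (96, 7)
```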
See the upstream project's [`README_origin`](./README_origin.md) for more details.

## Training
### Single machine, single card
```
export HIP_VISIBLE_DEVICES=0
cd iTransformer
sh ./scripts/multivariate_forecasting/ETT/iTransformer_ETTm2_train.sh
```
See the upstream project's [`README_origin`](./README_origin.md) for more details.

## Inference
Modify the last few lines of [`run.py`](./run.py) as follows:
```
# exp.test(setting, test=1)
exp.predict(setting, load=True)
```
```
export HIP_VISIBLE_DEVICES=0
sh ./scripts/multivariate_forecasting/ETT/iTransformer_ETTm2_infer.sh
# ETTm2 is sampled at 15-minute intervals, so the default one-day horizon corresponds to pred_len = 96
```

## Results
`Input:`
```
2018-06-26 08:45:00,38.198001861572266,12.314000129699707,50.18000030517578,13.37600040435791,-11.53600025177002,-2.5910000801086426,42.03099822998047
2018-06-26 09:00:00,38.36600112915039,11.47599983215332,50.26100158691406,12.62600040435791,-11.53600025177002,-2.5910000801086426,42.69049835205078
...
2018-06-26 19:30:00,40.459999084472656,11.392000198364258,51.84199905395508,11.928999900817873,-11.53600025177002,-1.4179999828338623,45.54650115966797
2018-06-26 19:45:00,43.2239990234375,12.145999908447266,54.73699951171875,12.678999900817873,-11.53600025177002,-1.4179999828338623,45.32699966430664
```
`Output:`
```
# shape: (1, 96, 7)
[[[ 0.34952432  0.52950954  0.60350233  0.88908595 -1.1544497 0.14222175  1.547624  ]
  [ 0.33467558  0.5304026   0.5766822   0.8634169  -1.1414794 0.15061441  1.5383883 ]
  ...
  [ 0.38313037  0.55777836  0.58653885  0.8580381  -1.0596789 0.18568955  1.5027612 ]
  [ 0.3644999   0.55291736  0.57515836  0.8770145  -1.0512501 0.18641812  1.5099163 ]]]
```

### Accuracy
Test data: a split of [`ETTm2`](./dataset/ETT-small/ETTm2.csv) held out as the validation set; inference framework: PyTorch.

|  device   | train_loss |  mse | mae |
|:---------:|:----------:|:----------:|:----------:|
| DCU Z100L |   0.2107   |   0.1852   |   0.2718   |
| GPU V100S |   0.2107   |   0.1852   |   0.2718   |
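
The mse/mae columns are the usual elementwise metrics averaged over the test windows. A minimal sketch, with toy arrays standing in for predictions and ground truth:

```python
import numpy as np

pred = np.array([[0.35, 0.53], [0.33, 0.53]])  # toy model output
true = np.array([[0.30, 0.50], [0.40, 0.55]])  # toy ground truth

mse = np.mean((pred - true) ** 2)  # mean squared error
mae = np.mean(np.abs(pred - true)) # mean absolute error
print(f"mse={mse:.4f} mae={mae:.4f}")
```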

## Application Scenarios
### Algorithm Category
`Time series forecasting`
### Key Application Industries
`Finance, IT operations, e-commerce, manufacturing, energy, healthcare`
## Source Repository and Issue Feedback
- http://developer.sourcefind.cn/codes/modelzoo/itransformer_pytorch.git
## References
- https://github.com/thuml/iTransformer.git