lstm.py 1.46 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
# import torch_mlu
from torch import dropout, nn
import torchvision.datasets as dsets
# from mlu_device import global_computing_device as g_com


# torch.manual_seed(1)

# Device configuration


class LSTMSimple(nn.Module):
    """Sequence classifier: stacked LSTM followed by a linear head on the last step.

    Args:
        input_size: Number of features per time step of the input.
        hidden_size: Number of features in the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        n_classes: Number of output features produced by the linear head.
        device: Stored on the module as ``self.device`` for callers that read
            it. The initial LSTM states are created on the input's own device,
            so this value is not needed by ``forward``.
        dropout: Dropout probability applied between stacked LSTM layers.
            ``nn.LSTM`` never applies it after the last layer, so it is only
            used when ``num_layers > 1``.
    """

    def __init__(self, input_size, hidden_size, num_layers, n_classes, device, dropout=0.5):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # With a single layer, inter-layer dropout is a no-op and a non-zero
        # value only makes nn.LSTM emit a UserWarning, so pass 0.0 instead.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_size, n_classes)
        self.device = device

    def forward(self, x):
        """Run the LSTM over ``x`` and classify from the last time step.

        Args:
            x: Batched input of shape ``(batch, time_step, input_size)``
                (the LSTM is built with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, n_classes)``: the linear head applied to
            the top-layer LSTM output at the final time step.
        """
        # Zero initial hidden and cell states, shape (num_layers, batch, hidden).
        # new_zeros matches x's dtype and device, so double-precision or GPU
        # inputs work without a CPU allocation followed by a device copy.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        hidden0 = x.new_zeros(state_shape)
        cell0 = x.new_zeros(state_shape)

        # out: (batch, time_step, hidden_size); final states are not needed.
        out, _ = self.lstm(x, (hidden0, cell0))

        # Classify using the output at the last time step only.
        return self.linear(out[:, -1, :])




if __name__ == "__main__":
    # This module only defines the LSTMSimple model; running it directly
    # does nothing.
    pass