import torch
# import torch_mlu
from torch import nn
# from mlu_device import global_computing_device as g_com

# torch.manual_seed(1)

# Device configuration


class LSTMSimple(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, n_classes, device):
        super(LSTMSimple, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # dropout is only applied between stacked LSTM layers, so it has an
        # effect only when num_layers > 1
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            dropout=0.5, batch_first=True)
        self.linear = nn.Linear(hidden_size, n_classes)
        self.device = device

    def forward(self, x):
        # x shape     (batch, time_step, input_size)
        # out shape   (batch, time_step, hidden_size)
        # h_n shape   (n_layers, batch, hidden_size)
        # h_c shape   (n_layers, batch, hidden_size)

        # Initialize the hidden state and memory cell
        hidden0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(self.device)
        cell0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(self.device)

        # Forward propagate the LSTM
        out, (final_hidden, final_cell) = self.lstm(x, (hidden0, cell0))

        # Take the output at the last time step and map it to class scores
        # out = self.fc(out[:, -1, :])
        # print(final_hidden[-1])
        return self.linear(out[:, -1, :])
        # return out


if __name__ == "__main__":
    # FitModel('600182')
    # data_form_watch()
    pass
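

# Minimal smoke test (a sketch, not part of the original script): the sequence
# length, feature size, and hyperparameters below are illustrative assumptions,
# chosen only to demonstrate the (batch, time_step, input_size) call convention
# documented in forward(). Call _smoke_test() from the __main__ guard to run it.
def _smoke_test():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LSTMSimple(input_size=28, hidden_size=64, num_layers=2,
                       n_classes=10, device=device).to(device)
    x = torch.randn(4, 28, 28, device=device)  # (batch=4, time_step=28, input_size=28)
    logits = model(x)                           # (batch, n_classes)
    print(logits.shape)                         # expected: torch.Size([4, 10])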