# gnmt.py
import time

import torch
import torch.nn as nn

import seq2seq.data.config as config
from seq2seq.models.decoder import ResidualRecurrentDecoder
from seq2seq.models.encoder import ResidualRecurrentEncoder
from seq2seq.models.seq2seq_base import Seq2Seq

class GNMT(Seq2Seq):
    """
    GNMT v2 model
    """
    def __init__(self, vocab_size, hidden_size=1024, num_layers=4, dropout=0.2,
                 batch_first=False, share_embedding=True, fusion=True):
        """
        Constructor for the GNMT v2 model.

        :param vocab_size: size of vocabulary (number of tokens)
        :param hidden_size: internal hidden size of the model
        :param num_layers: number of layers, applies to both encoder and
            decoder
        :param dropout: probability of dropout (in encoder and decoder)
        :param batch_first: if True the model uses (batch, seq, feature)
            tensors, if False it uses (seq, batch, feature)
        :param share_embedding: if True embeddings are shared between encoder
            and decoder
        :param fusion: forwarded to the decoder (ResidualRecurrentDecoder)
        """

        super(GNMT, self).__init__(batch_first=batch_first)

        if share_embedding:
            embedder = nn.Embedding(vocab_size, hidden_size,
                                    padding_idx=config.PAD)
            nn.init.uniform_(embedder.weight.data, -0.1, 0.1)
        else:
            embedder = None

        self.embedder = embedder
        self.encoder = ResidualRecurrentEncoder(vocab_size, hidden_size,
                                                num_layers, dropout,
                                                batch_first, embedder)

        self.decoder = ResidualRecurrentDecoder(vocab_size, hidden_size,
                                                num_layers, dropout,
                                                batch_first, embedder,
                                                fusion=fusion)

    def forward(self, input_encoder, input_enc_len, input_decoder):
        if self.embedder:
            input_encoder = self.embedder(input_encoder)
            input_decoder = self.embedder(input_decoder)

        # Profiling instrumentation: synchronize before and after each stage
        # so the host-side wall-clock timestamps reflect completed GPU work.
        torch.cuda.synchronize()
        t1 = time.time()
        context = self.encode(input_encoder, input_enc_len)
        torch.cuda.synchronize()
        t2 = time.time()

        # The first copy performs the actual host-to-device transfer; the
        # second is a no-op (the tensor is already on the device) and so
        # measures only the call overhead.
        input_enc_len = input_enc_len.to(input_encoder.device, non_blocking=True)
        torch.cuda.synchronize()
        t5 = time.time()
        input_enc_len = input_enc_len.to(input_encoder.device, non_blocking=True)
        torch.cuda.synchronize()
        t6 = time.time()

        context = (context, input_enc_len, None)
        torch.cuda.synchronize()
        t3 = time.time()
        output, _, _ = self.decode(input_decoder, context)
        torch.cuda.synchronize()
        t4 = time.time()

        print("encode time is", (t2 - t1) * 1000)
        print("decode time is", (t4 - t3) * 1000)
        print("process time is", (t3 - t2) * 1000)
        print("process copy time1 is", (t5 - t2) * 1000)
        print("process copy time2 is", (t6 - t5) * 1000)
        print("process concat time is", (t3 - t6) * 1000)

        return output
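

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test showing
# how GNMT is constructed and how the instrumented forward() is called.
# Shapes assume the default batch_first=False layout, i.e. (seq_len, batch)
# token tensors, and that input_enc_len starts on the CPU so forward() can
# time the host-to-device copy. Vocab size, sequence length, and batch size
# are arbitrary illustration values; a CUDA device is required because
# forward() calls torch.cuda.synchronize().
if __name__ == "__main__":
    vocab_size = 32000
    model = GNMT(vocab_size, hidden_size=1024, num_layers=4).cuda()

    seq_len, batch = 50, 128
    src = torch.randint(0, vocab_size, (seq_len, batch), device="cuda")
    tgt = torch.randint(0, vocab_size, (seq_len, batch), device="cuda")
    src_len = torch.full((batch,), seq_len, dtype=torch.int64)  # CPU lengths

    with torch.no_grad():
        output = model(src, src_len, tgt)
    print("output shape:", tuple(output.shape))

    # Alternative timing (a sketch, not the author's method): CUDA events
    # record on the GPU stream and need only one synchronize at the end,
    # avoiding the per-checkpoint host stalls used in forward() above.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    with torch.no_grad():
        model(src, src_len, tgt)
    end.record()
    torch.cuda.synchronize()
    print("full forward (CUDA events):", start.elapsed_time(end), "ms")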