Commit 31adbb24 authored by Rémi Louf's avatar Rémi Louf
Browse files

add class wireframes for Bert decoder

parent dda1adad
...@@ -331,6 +331,14 @@ class BertEncoderLayer(nn.Module): ...@@ -331,6 +331,14 @@ class BertEncoderLayer(nn.Module):
return outputs return outputs
class BertDecoderLayer(nn.Module):
def __init__(self, config):
raise NotImplementedError
def forward(self, hidden_state, encoder_output):
raise NotImplementedError
class BertEncoder(nn.Module): class BertEncoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BertEncoder, self).__init__() super(BertEncoder, self).__init__()
...@@ -363,6 +371,14 @@ class BertEncoder(nn.Module): ...@@ -363,6 +371,14 @@ class BertEncoder(nn.Module):
return outputs # last-layer hidden state, (all hidden states), (all attentions) return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertDecoder(nn.Module):
def __init__(self, config):
raise NotImplementedError
def forward(self, encoder_output):
raise NotImplementedError
class BertPooler(nn.Module): class BertPooler(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BertPooler, self).__init__() super(BertPooler, self).__init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment