Commit 45dc04f3 authored by thomwolf's avatar thomwolf
Browse files

tf model [WIP]

parent 24831477
...@@ -111,7 +111,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -111,7 +111,7 @@ class MultiHeadAttention(torch.nn.Module):
v = self.split_into_heads(v, batch_size) v = self.split_into_heads(v, batch_size)
if layer_past is not None: if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1] past_key, past_value = layer_past[0], layer_past[1]
k = torch.cat((past_key, k), dim=-1) k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2) v = torch.cat((past_value, v), dim=-2)
present = torch.stack((k, v)) present = torch.stack((k, v))
...@@ -167,25 +167,25 @@ class EncoderLayer(torch.nn.Module): ...@@ -167,25 +167,25 @@ class EncoderLayer(torch.nn.Module):
class CTRLPreTrainedModel(PreTrainedModel): class CTRLPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """ An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models. a simple interface for dowloading and loading pretrained models.
"""
config_class = CTRLConfig
pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
def _init_weights(self, module):
""" Initialize the weights.
""" """
config_class = CTRLConfig if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP # Slightly different from the TF version which uses truncated_normal for initialization
base_model_prefix = "transformer" # cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
def _init_weights(self, module): if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
CTRL_START_DOCSTRING = r""" CTRL model was proposed in CTRL_START_DOCSTRING = r""" CTRL model was proposed in
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment